From e309ea20beee1527a76764cf087755ee9d8bd449 Mon Sep 17 00:00:00 2001 From: Matthew Smith <mjsmith6@illinois.edu> Date: Wed, 25 Aug 2021 16:46:35 -0500 Subject: [PATCH] split ArrayContainerT into ContainerT and ArrayOrContainerT --- arraycontext/container/__init__.py | 12 ++++++--- arraycontext/container/traversal.py | 40 ++++++++++++++--------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 5142c05..1046f99 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -3,11 +3,16 @@ """ .. currentmodule:: arraycontext -.. class:: ArrayContainerT - :canonical: arraycontext.container.ArrayContainerT +.. class:: ContainerT + :canonical: arraycontext.container.ContainerT :class:`~typing.TypeVar` for array container-like objects. +.. class:: ArrayOrContainerT + :canonical: arraycontext.container.ArrayOrContainerT + + :class:`~typing.TypeVar` for arrays or array container-like objects. + .. autoclass:: ArrayContainer Serialization/deserialization @@ -52,7 +57,8 @@ from arraycontext.context import ArrayContext from typing import Any, Iterable, Tuple, TypeVar, Optional import numpy as np -ArrayContainerT = TypeVar("ArrayContainerT") +ContainerT = TypeVar("ContainerT") +ArrayOrContainerT = TypeVar("ArrayOrContainerT") # {{{ ArrayContainer diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index cfdea4c..ed6dc0c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -58,7 +58,7 @@ import numpy as np from arraycontext.context import ArrayContext from arraycontext.container import ( - ArrayContainerT, is_array_container, + ContainerT, ArrayOrContainerT, is_array_container, serialize_container, deserialize_container) @@ -66,9 +66,9 @@ from arraycontext.container import ( def _map_array_container_impl( f: Callable[[Any], Any], - ary: ArrayContainerT, *, + ary: ArrayOrContainerT, *, leaf_cls: Optional[type] = None, - recursive: bool = False) -> ArrayContainerT: + recursive: bool = False) -> ArrayOrContainerT: """Helper for :func:`rec_map_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -76,7 +76,7 @@ def _map_array_container_impl( specific container classes. By default, the recursion is stopped when a non-:class:`ArrayContainer` class is encountered. """ - def rec(_ary: ArrayContainerT) -> ArrayContainerT: + def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: if type(_ary) is leaf_cls: # type(ary) is never None return f(_ary) elif is_array_container(_ary): @@ -93,9 +93,9 @@ def _map_array_container_impl( def _multimap_array_container_impl( f: Callable[..., Any], *args: Any, - reduce_func: Callable[[Any, Iterable[Tuple[Any, Any]]], Any] = None, + reduce_func: Callable[[ContainerT, Iterable[Tuple[Any, Any]]], Any] = None, leaf_cls: Optional[type] = None, - recursive: bool = False) -> ArrayContainerT: + recursive: bool = False) -> ArrayOrContainerT: """Helper for :func:`rec_multimap_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -142,13 +142,13 @@ def _multimap_array_container_impl( if len(container_indices) == 1 and reduce_func is None: # NOTE: if we just have one ArrayContainer in args, passing it through # _map_array_container_impl should be faster - def wrapper(ary: ArrayContainerT) -> ArrayContainerT: + def wrapper(ary: ContainerT) -> ContainerT: new_args = list(args) new_args[container_indices[0]] = ary return f(*new_args) update_wrapper(wrapper, f) - template_ary: ArrayContainerT = args[container_indices[0]] + template_ary: ContainerT = args[container_indices[0]] return _map_array_container_impl( wrapper, template_ary, leaf_cls=leaf_cls, recursive=recursive) @@ -165,7 +165,7 @@ def _multimap_array_container_impl( def map_array_container( f: Callable[[Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: + ary: ArrayOrContainerT) -> ArrayOrContainerT: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but @@ -202,7 +202,7 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: def rec_map_array_container( f: Callable[[Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: + ary: ArrayOrContainerT) -> ArrayOrContainerT: r"""Applies *f* recursively to an :class:`ArrayContainer`. For a non-recursive version see :func:`map_array_container`. @@ -214,7 +214,7 @@ def rec_map_array_container( def mapped_over_array_containers( - f: Callable[[Any], Any]) -> Callable[[ArrayContainerT], ArrayContainerT]: + f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]: """Decorator around :func:`rec_map_array_container`.""" wrapper = partial(rec_map_array_container, f) update_wrapper(wrapper, f) @@ -249,7 +249,7 @@ def multimapped_over_array_containers( # {{{ keyed array container traversal def keyed_map_array_container(f: Callable[[Any, Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: + ary: ArrayOrContainerT) -> ArrayOrContainerT: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`map_array_container`, but *f* also takes an @@ -269,7 +269,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: + ary: ArrayOrContainerT) -> ArrayOrContainerT: """ Works similarly to :func:`rec_map_array_container`, except that *f* also takes in a traversal path to the leaf array. The traversal path argument is @@ -278,7 +278,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], """ def rec(keys: Tuple[Union[str, int], ...], - _ary: ArrayContainerT) -> ArrayContainerT: + _ary: ArrayOrContainerT) -> ArrayOrContainerT: if is_array_container(_ary): return deserialize_container(_ary, [ (key, rec(keys + (key,), subary)) @@ -297,7 +297,7 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], - ary: ArrayContainerT) -> Any: + ary: ArrayOrContainerT) -> Any: """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of the @@ -307,7 +307,7 @@ def rec_map_reduce_array_container( type :class:`arraycontext.ArrayContext.array_types` and returns an array of the same type or a scalar. """ - def rec(_ary: ArrayContainerT) -> ArrayContainerT: + def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: if is_array_container(_ary): return reduce_func([ rec(subary) for _, subary in serialize_container(_ary) @@ -333,7 +333,7 @@ def rec_multimap_reduce_array_container( """ # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier - def _reduce_wrapper(ary: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any: + def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: return reduce_func([subary for _, subary in iterable]) return _multimap_array_container_impl( @@ -347,8 +347,8 @@ def rec_multimap_reduce_array_container( @singledispatch def freeze( - ary: ArrayContainerT, - actx: Optional[ArrayContext] = None) -> ArrayContainerT: + ary: ArrayOrContainerT, + actx: Optional[ArrayContext] = None) -> ArrayOrContainerT: r"""Freezes recursively by going through all components of the :class:`ArrayContainer` *ary*. @@ -372,7 +372,7 @@ def freeze( @singledispatch -def thaw(ary: ArrayContainerT, actx: ArrayContext) -> ArrayContainerT: +def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT: r"""Thaws recursively by going through all components of the :class:`ArrayContainer` *ary*. -- GitLab