diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index b367382c834aa3c1a784cef8e4467475bd35e592..12475db53b048b3d1dfa93faa98fb7f42eb446ea 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -47,6 +47,8 @@ from .container.traversal import ( rec_multimap_array_container, mapped_over_array_containers, multimapped_over_array_containers, + keyed_map_array_container, + rec_keyed_map_array_container, thaw, freeze, from_numpy, to_numpy) @@ -76,6 +78,8 @@ __all__ = ( "rec_map_array_container", "rec_multimap_array_container", "mapped_over_array_containers", "multimapped_over_array_containers", + "keyed_map_array_container", + "rec_keyed_map_array_container", "thaw", "freeze", "from_numpy", "to_numpy", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 03587b51c5f2ce11809586b27b47cb7d649186ea..f8dc5f2567b94128c60a3b8d7f1a561cb8482855 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -48,7 +48,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union, Tuple from functools import update_wrapper, partial, singledispatch import numpy as np @@ -232,6 +232,67 @@ def multimapped_over_array_containers( update_wrapper(wrapper, f) return wrapper + +def keyed_map_array_container( + f: Callable[[Union[str, int], Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: + r"""Applies *f* to all components of an :class:`ArrayContainer`. + + Works similar to :func:`map_array_container`, but *f* also takes an + identifier of the array in the container *ary*. + + For a recursive version, see :func:`rec_keyed_map_array_container`. + + :param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s, + or an instance of a base array type. + """ + if is_array_container(ary): + return deserialize_container(ary, [ + (key, f(key, subary)) for key, subary in serialize_container(ary) + ]) + else: + raise ValueError("Not an array-container, i.e. unknown key to pass.") + + +def _keyed_map_array_container_impl( + f: Callable[[Any], Any], + ary: ArrayContainerT, *, + leaf_cls: Optional[type] = None, + recursive: bool = False) -> ArrayContainerT: + """Helper for :func:`rec_keyed_map_array_container`. + + :param leaf_cls: class on which we call *f* directly. This is mostly + useful in the recursive setting, where it can stop the recursion on + specific container classes. By default, the recursion is stopped when + a non-:class:`ArrayContainer` class is encountered. + """ + def rec(keys: Tuple[Union[str, int], ...], + _ary: ArrayContainerT) -> ArrayContainerT: + if type(_ary) is leaf_cls: # type(ary) is never None + return f(keys, _ary) + elif is_array_container(_ary): + return deserialize_container(_ary, [ + (key, frec(keys+(key,), subary)) + for key, subary in serialize_container(_ary) + ]) + else: + return f(keys, _ary) + + frec = rec if recursive else f + return rec((), ary) + + +def rec_keyed_map_array_container( + f: Callable[[Tuple[Union[str, int], ...], Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: + """ + Works similar to :func:`rec_map_array_container`, except that *f* also + takes in a traversal path to the leaf array. The traversal path argument is + passed in as a tuple of identifiers of the arrays traversed before reaching + the current array. + """ + return _keyed_map_array_container_impl(f, ary, recursive=True) + # }}}