diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 76242ef45afa83aca5d734ac39a59c77e8f58ea8..61203029ca484cba9918afc54a4c547b9ecf8905 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -53,6 +53,8 @@ from .container.traversal import ( rec_multimap_array_container, mapped_over_array_containers, multimapped_over_array_containers, + map_reduce_array_container, + multimap_reduce_array_container, rec_map_reduce_array_container, rec_multimap_reduce_array_container, thaw, freeze, @@ -87,6 +89,7 @@ __all__ = ( "rec_map_array_container", "rec_multimap_array_container", "mapped_over_array_containers", "multimapped_over_array_containers", + "map_reduce_array_container", "multimap_reduce_array_container", "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", "from_numpy", "to_numpy", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 31b0545ad8663e1bee83de2be7ed15feb5379a95..53c2986da62308d981fe968a35ca48f35d1d182b 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -8,6 +8,8 @@ .. autofunction:: rec_map_array_container .. autofunction:: rec_multimap_array_container +.. autofunction:: map_reduce_array_container +.. autofunction:: multimap_reduce_array_container .. autofunction:: rec_map_reduce_array_container .. autofunction:: rec_multimap_reduce_array_container @@ -294,6 +296,48 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], # {{{ array container reductions +def map_reduce_array_container( + reduce_func: Callable[[Iterable[Any]], Any], + map_func: Callable[[Any], Any], + ary: ArrayOrContainerT) -> Any: + """Perform a map-reduce over array containers. + + :param reduce_func: callable used to reduce over the components of *ary* + if *ary* is an :class:`~arraycontext.ArrayContainer`. + :param map_func: callable used to map a single array of type + :class:`arraycontext.ArrayContext.array_types`. Returns an array of the + same type or a scalar. + """ + if is_array_container(ary): + return reduce_func([ + map_func(subary) for _, subary in serialize_container(ary) + ]) + else: + return map_func(ary) + + +def multimap_reduce_array_container( + reduce_func: Callable[[Iterable[Any]], Any], + map_func: Callable[..., Any], + *args: Any) -> Any: + r"""Perform a map-reduce over multiple array containers. + + :param reduce_func: callable used to reduce over the components of any + :class:`~arraycontext.ArrayContainer`\ s in *\*args*. + :param map_func: callable used to map a single array of type + :class:`arraycontext.ArrayContext.array_types`. Returns an array of the + same type or a scalar. + """ + # NOTE: this wrapper matches the signature of `deserialize_container` + # to make plugging into `_multimap_array_container_impl` easier + def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: + return reduce_func([subary for _, subary in iterable]) + + return _multimap_array_container_impl( + map_func, *args, + reduce_func=_reduce_wrapper, leaf_cls=None, recursive=False) + + def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any],