diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index b2abc74071ae63eb4ed246921f2a9c9e41c93cef..31b0545ad8663e1bee83de2be7ed15feb5379a95 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -300,12 +300,40 @@ def rec_map_reduce_array_container( ary: ArrayOrContainerT) -> Any: """Perform a map-reduce over array containers recursively. - :param reduce_func: callable used to reduce over the components of the - :class:`~arraycontext.ArrayContainer`. - :param map_func: callable used to map a single component of the - :class:`~arraycontext.ArrayContainer`. The callable takes arrays of - type :class:`arraycontext.ArrayContext.array_types` and returns an - array of the same type or a scalar. + :param reduce_func: callable used to reduce over the components of *ary* + (and those of its sub-containers) if *ary* is a + :class:`~arraycontext.ArrayContainer`. Must be associative. + :param map_func: callable used to map a single array of type + :class:`arraycontext.ArrayContext.array_types`. Returns an array of the + same type or a scalar. + + .. note:: + + The traversal order is unspecified. *reduce_func* must be associative in + order to guarantee a sensible result. This is because *reduce_func* may be + called on subsets of the component arrays, and then again (potentially + multiple times) on the results. As an example, consider a container made up + of two sub-containers, *subcontainer0* and *subcontainer1*, that each + contain two component arrays, *array0* and *array1*. The same result must be + computed whether traversing recursively:: + + reduce_func([ + reduce_func([ + map_func(subcontainer0.array0), + map_func(subcontainer0.array1)]), + reduce_func([ + map_func(subcontainer1.array0), + map_func(subcontainer1.array1)])]) + + reducing all of the arrays at once:: + + reduce_func([ + map_func(subcontainer0.array0), + map_func(subcontainer0.array1), + map_func(subcontainer1.array0), + map_func(subcontainer1.array1)]) + + or any other such traversal. """ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: if is_array_container(_ary): @@ -322,14 +350,20 @@ def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], *args: Any) -> Any: - """Perform a map-reduce over multiple array containers recursively. - - :param reduce_func: callable used to reduce over the components of the - :class:`~arraycontext.ArrayContainer`. - :param map_func: callable used to map a single component of the - :class:`~arraycontext.ArrayContainer`. The callable takes arrays of - type :class:`arraycontext.ArrayContext.array_types` and returns an - array of the same type or a scalar. + r"""Perform a map-reduce over multiple array containers recursively. + + :param reduce_func: callable used to reduce over the components of any + :class:`~arraycontext.ArrayContainer`\ s in *\*args* (and those of their + sub-containers). Must be associative. + :param map_func: callable used to map a single array of type + :class:`arraycontext.ArrayContext.array_types`. Returns an array of the + same type or a scalar. + + .. note:: + + The traversal order is unspecified. *reduce_func* must be associative in + order to guarantee a sensible result. See + :func:`rec_map_reduce_array_container` for additional details. """ # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier