From 12616418b6efd045a6fb66581f2a86e5275124d8 Mon Sep 17 00:00:00 2001 From: Alex Fikl <alexfikl@gmail.com> Date: Fri, 16 Jul 2021 17:59:54 -0500 Subject: [PATCH] Add generic container reductions (#62) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add recursive container reductions * extend reductions to array containers * make rec_reduce_array_container more specific * rename make_container to process_container * rename rec_reduce_container to rec_map_reduce_container * fix typo in comment Co-authored-by: Andreas Klöckner <inform@tiker.net> Co-authored-by: Andreas Klöckner <inform@tiker.net> --- arraycontext/__init__.py | 3 + arraycontext/container/__init__.py | 4 +- arraycontext/container/traversal.py | 82 +++++++++++++++++++++--- arraycontext/impl/pyopencl/fake_numpy.py | 57 ++++++++-------- arraycontext/impl/pytato/fake_numpy.py | 36 +++++++---- test/test_arraycontext.py | 16 +++-- 6 files changed, 143 insertions(+), 55 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index a338059..aafcfd8 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -52,6 +52,8 @@ from .container.traversal import ( rec_multimap_array_container, mapped_over_array_containers, multimapped_over_array_containers, + rec_map_reduce_array_container, + rec_multimap_reduce_array_container, thaw, freeze, from_numpy, to_numpy) @@ -83,6 +85,7 @@ __all__ = ( "rec_map_array_container", "rec_multimap_array_container", "mapped_over_array_containers", "multimapped_over_array_containers", + "rec_map_reduce_array_container", "rec_multimap_reduce_array_container", "thaw", "freeze", "from_numpy", "to_numpy", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 44b7e9e..5142c05 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -139,7 +139,7 @@ def is_array_container_type(cls: type) -> bool: return ( cls is ArrayContainer or (serialize_container.dispatch(cls) - is not serialize_container.__wrapped__)) # type: ignore + is not serialize_container.__wrapped__)) # type:ignore[attr-defined] def is_array_container(ary: Any) -> bool: @@ -148,7 +148,7 @@ def is_array_container(ary: Any) -> bool: :func:`serialize_container`. """ return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type: ignore + is not serialize_container.__wrapped__) # type:ignore[attr-defined] @singledispatch diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 70babe6..cfdea4c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -8,6 +8,9 @@ .. autofunction:: rec_map_array_container .. autofunction:: rec_multimap_array_container +.. autofunction:: rec_map_reduce_array_container +.. autofunction:: rec_multimap_reduce_array_container + Traversing decorators ~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: mapped_over_array_containers @@ -48,7 +51,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, List, Optional, Union, Tuple +from typing import Any, Callable, Iterable, List, Optional, Union, Tuple from functools import update_wrapper, partial, singledispatch import numpy as np @@ -59,7 +62,7 @@ from arraycontext.container import ( serialize_container, deserialize_container) -# {{{ array container traversal +# {{{ array container traversal helpers def _map_array_container_impl( f: Callable[[Any], Any], @@ -78,8 +81,8 @@ def _map_array_container_impl( return f(_ary) elif is_array_container(_ary): return deserialize_container(_ary, [ - (key, frec(subary)) for key, subary in serialize_container(_ary) - ]) + (key, frec(subary)) for key, subary in serialize_container(_ary) + ]) else: return f(_ary) @@ -90,6 +93,7 @@ def _map_array_container_impl( def _multimap_array_container_impl( f: Callable[..., Any], *args: Any, + reduce_func: Callable[[Any, Iterable[Tuple[Any, Any]]], Any] = None, leaf_cls: Optional[type] = None, recursive: bool = False) -> ArrayContainerT: """Helper for :func:`rec_multimap_array_container`. @@ -124,9 +128,9 @@ def _multimap_array_container_impl( new_args[i] = subary - result.append((key, frec(*new_args))) # type: ignore + result.append((key, frec(*new_args))) # type: ignore[operator] - return deserialize_container(template_ary, result) + return process_container(template_ary, result) # type: ignore[operator] container_indices: List[int] = [ i for i, arg in enumerate(args) @@ -135,7 +139,7 @@ def _multimap_array_container_impl( if not container_indices: return f(*args) - if len(container_indices) == 1: + if len(container_indices) == 1 and reduce_func is None: # NOTE: if we just have one ArrayContainer in args, passing it through # _map_array_container_impl should be faster def wrapper(ary: ArrayContainerT) -> ArrayContainerT: @@ -149,9 +153,15 @@ def _multimap_array_container_impl( wrapper, template_ary, leaf_cls=leaf_cls, recursive=recursive) + process_container = deserialize_container if reduce_func is None else reduce_func frec = rec if recursive else f + return rec(*args) +# }}} + + +# {{{ array container traversal def map_array_container( f: Callable[[Any], Any], @@ -233,6 +243,10 @@ def multimapped_over_array_containers( update_wrapper(wrapper, f) return wrapper +# }}} + + +# {{{ keyed array container traversal def keyed_map_array_container(f: Callable[[Any, Any], Any], ary: ArrayContainerT) -> ArrayContainerT: @@ -266,9 +280,8 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], def rec(keys: Tuple[Union[str, int], ...], _ary: ArrayContainerT) -> ArrayContainerT: if is_array_container(_ary): - return deserialize_container(_ary, [ - (key, rec(keys+(key,), subary)) + (key, rec(keys + (key,), subary)) for key, subary in serialize_container(_ary) ]) else: @@ -279,6 +292,57 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], # }}} +# {{{ array container reductions + +def rec_map_reduce_array_container( + reduce_func: Callable[[Iterable[Any]], Any], + map_func: Callable[[Any], Any], + ary: ArrayContainerT) -> Any: + """Perform a map-reduce over array containers recursively. + + :param reduce_func: callable used to reduce over the components of the + :class:`~arraycontext.ArrayContainer`. + :param map_func: callable used to map a single component of the + :class:`~arraycontext.ArrayContainer`. The callable takes arrays of + type :class:`arraycontext.ArrayContext.array_types` and returns an + array of the same type or a scalar. + """ + def rec(_ary: ArrayContainerT) -> ArrayContainerT: + if is_array_container(_ary): + return reduce_func([ + rec(subary) for _, subary in serialize_container(_ary) + ]) + else: + return map_func(_ary) + + return rec(ary) + + +def rec_multimap_reduce_array_container( + reduce_func: Callable[[Iterable[Any]], Any], + map_func: Callable[..., Any], + *args: Any) -> Any: + """Perform a map-reduce over multiple array containers recursively. + + :param reduce_func: callable used to reduce over the components of the + :class:`~arraycontext.ArrayContainer`. + :param map_func: callable used to map a single component of the + :class:`~arraycontext.ArrayContainer`. The callable takes arrays of + type :class:`arraycontext.ArrayContext.array_types` and returns an + array of the same type or a scalar. + """ + # NOTE: this wrapper matches the signature of `deserialize_container` + # to make plugging into `_multimap_array_container_impl` easier + def _reduce_wrapper(ary: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any: + return reduce_func([subary for _, subary in iterable]) + + return _multimap_array_container_impl( + map_func, *args, + reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True) + +# }}} + + # {{{ freeze/thaw @singledispatch diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index bd9eb08..01f80f3 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -26,13 +26,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from functools import partial +from functools import partial, reduce import operator from arraycontext.fake_numpy import \ BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace -from arraycontext.container.traversal import (rec_multimap_array_container, - rec_map_array_container) +from arraycontext.container.traversal import ( + rec_multimap_array_container, rec_map_array_container, + rec_map_reduce_array_container, + ) try: import pyopencl as cl # noqa: F401 @@ -104,24 +106,35 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_multimap_array_container(where_inner, criterion, then, else_) def sum(self, a, dtype=None): - result = cl_array.sum(a, dtype=dtype, queue=self._array_context.queue) + result = rec_map_reduce_array_container( + sum, + partial(cl_array.sum, dtype=dtype, queue=self._array_context.queue), + a) + if not self._array_context._force_device_scalars: result = result.get()[()] - return result def min(self, a): - result = cl_array.min(a, queue=self._array_context.queue) + queue = self._array_context.queue + result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.minimum, queue=queue)), + partial(cl_array.min, queue=queue), + a) + if not self._array_context._force_device_scalars: result = result.get()[()] - return result def max(self, a): - result = cl_array.max(a, queue=self._array_context.queue) + queue = self._array_context.queue + result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.maximum, queue=queue)), + partial(cl_array.max, queue=queue), + a) + if not self._array_context._force_device_scalars: result = result.get()[()] - return result def stack(self, arrays, axis=0): @@ -163,25 +176,15 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_map_array_container(_rec_ravel, a) def vdot(self, x, y, dtype=None): - import pyopencl.array as cl_array - from arraycontext import is_array_container, serialize_container - - def _rec_vdot(xi, yi): - if is_array_container(xi): - assert type(xi) == type(yi) - return sum(_rec_vdot(subxi, subyi) - for (_, subxi), (_, subyi) in zip( - serialize_container(xi), serialize_container(yi) - )) - else: - result = cl_array.vdot(xi, yi, - dtype=dtype, queue=self._array_context.queue) - if not self._array_context._force_device_scalars: - result = result.get()[()] - - return result + from arraycontext import rec_multimap_reduce_array_container + result = rec_multimap_reduce_array_container( + sum, + partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue), + x, y) - return _rec_vdot(x, y) + if not self._array_context._force_device_scalars: + result = result.get()[()] + return result # }}} diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 9793dde..f17a4ab 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -21,11 +21,15 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - -from arraycontext.fake_numpy import \ - BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace -from arraycontext.container.traversal import \ - rec_multimap_array_container, rec_map_array_container +from functools import partial, reduce + +from arraycontext.fake_numpy import ( + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, + ) +from arraycontext.container.traversal import ( + rec_multimap_array_container, rec_map_array_container, + rec_map_reduce_array_container, + ) import pytato as pt @@ -82,20 +86,26 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace): return rec_multimap_array_container(pt.where, criterion, then, else_) def sum(self, a, dtype=None): - if dtype not in [a.dtype, None]: - raise NotImplementedError - return pt.sum(a) + def _pt_sum(ary): + if dtype not in [ary.dtype, None]: + raise NotImplementedError + + return pt.sum(ary) + + return rec_map_reduce_array_container(sum, _pt_sum, a) def min(self, a): - return pt.amin(a) + return rec_map_reduce_array_container( + partial(reduce, pt.minimum), pt.amin, a) def max(self, a): - return pt.amax(a) + return rec_map_reduce_array_container( + partial(reduce, pt.maximum), pt.amax, a) def stack(self, arrays, axis=0): - return rec_multimap_array_container(lambda *args: pt.stack(arrays=args, - axis=axis), - *arrays) + return rec_multimap_array_container( + lambda *args: pt.stack(arrays=args, axis=axis), + *arrays) # {{{ relational operators diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index c9c26e2..5bb9fb7 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -253,19 +253,27 @@ def assert_close_to_numpy_in_containers(actx, op, args): # {{{ np.function same as numpy @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ - ("sin", 1, np.float64), - ("sin", 1, np.complex128), - ("exp", 1, np.float64), + # float only ("arctan2", 2, np.float64), ("minimum", 2, np.float64), ("maximum", 2, np.float64), ("where", 3, np.float64), + ("min", 1, np.float64), + ("max", 1, np.float64), + + # float + complex + ("sin", 1, np.float64), + ("sin", 1, np.complex128), + ("exp", 1, np.float64), + ("exp", 1, np.complex128), ("conj", 1, np.float64), ("conj", 1, np.complex128), ("vdot", 2, np.float64), ("vdot", 2, np.complex128), ("abs", 1, np.float64), ("abs", 1, np.complex128), + ("sum", 1, np.float64), + ("sum", 1, np.complex64), ]) def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): actx = actx_factory() @@ -494,7 +502,7 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory): # {{{ reductions same as numpy @pytest.mark.parametrize("op", ["sum", "min", "max"]) -def test_dof_array_reductions_same_as_numpy(actx_factory, op): +def test_reductions_same_as_numpy(actx_factory, op): actx = actx_factory() ary = np.random.randn(3000) -- GitLab