From 12616418b6efd045a6fb66581f2a86e5275124d8 Mon Sep 17 00:00:00 2001
From: Alex Fikl <alexfikl@gmail.com>
Date: Fri, 16 Jul 2021 17:59:54 -0500
Subject: [PATCH] Add generic container reductions (#62)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add recursive container reductions

* extend reductions to array containers

* make rec_reduce_array_container more specific

* rename make_container to process_container

* rename rec_reduce_container to rec_map_reduce_container

* fix typo in comment

Co-authored-by: Andreas Klöckner <inform@tiker.net>

Co-authored-by: Andreas Klöckner <inform@tiker.net>
---
 arraycontext/__init__.py                 |  3 +
 arraycontext/container/__init__.py       |  4 +-
 arraycontext/container/traversal.py      | 82 +++++++++++++++++++++---
 arraycontext/impl/pyopencl/fake_numpy.py | 57 ++++++++--------
 arraycontext/impl/pytato/fake_numpy.py   | 36 +++++++----
 test/test_arraycontext.py                | 16 +++--
 6 files changed, 143 insertions(+), 55 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index a338059..aafcfd8 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -52,6 +52,8 @@ from .container.traversal import (
         rec_multimap_array_container,
         mapped_over_array_containers,
         multimapped_over_array_containers,
+        rec_map_reduce_array_container,
+        rec_multimap_reduce_array_container,
         thaw, freeze,
         from_numpy, to_numpy)
 
@@ -83,6 +85,7 @@ __all__ = (
         "rec_map_array_container", "rec_multimap_array_container",
         "mapped_over_array_containers",
         "multimapped_over_array_containers",
+        "rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
         "thaw", "freeze",
         "from_numpy", "to_numpy",
 
diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 44b7e9e..5142c05 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -139,7 +139,7 @@ def is_array_container_type(cls: type) -> bool:
     return (
             cls is ArrayContainer
             or (serialize_container.dispatch(cls)
-                is not serialize_container.__wrapped__))    # type: ignore
+                is not serialize_container.__wrapped__))  # type:ignore[attr-defined]
 
 
 def is_array_container(ary: Any) -> bool:
@@ -148,7 +148,7 @@ def is_array_container(ary: Any) -> bool:
         :func:`serialize_container`.
     """
     return (serialize_container.dispatch(ary.__class__)
-            is not serialize_container.__wrapped__)         # type: ignore
+            is not serialize_container.__wrapped__)       # type:ignore[attr-defined]
 
 
 @singledispatch
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 70babe6..cfdea4c 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -8,6 +8,9 @@
 .. autofunction:: rec_map_array_container
 .. autofunction:: rec_multimap_array_container
 
+.. autofunction:: rec_map_reduce_array_container
+.. autofunction:: rec_multimap_reduce_array_container
+
 Traversing decorators
 ~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: mapped_over_array_containers
@@ -48,7 +51,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any, Callable, List, Optional, Union, Tuple
+from typing import Any, Callable, Iterable, List, Optional, Union, Tuple
 from functools import update_wrapper, partial, singledispatch
 
 import numpy as np
@@ -59,7 +62,7 @@ from arraycontext.container import (
         serialize_container, deserialize_container)
 
 
-# {{{ array container traversal
+# {{{ array container traversal helpers
 
 def _map_array_container_impl(
         f: Callable[[Any], Any],
@@ -78,8 +81,8 @@ def _map_array_container_impl(
             return f(_ary)
         elif is_array_container(_ary):
             return deserialize_container(_ary, [
-                    (key, frec(subary)) for key, subary in serialize_container(_ary)
-                    ])
+                (key, frec(subary)) for key, subary in serialize_container(_ary)
+                ])
         else:
             return f(_ary)
 
@@ -90,6 +93,7 @@ def _map_array_container_impl(
 def _multimap_array_container_impl(
         f: Callable[..., Any],
         *args: Any,
+        reduce_func: Callable[[Any, Iterable[Tuple[Any, Any]]], Any] = None,
         leaf_cls: Optional[type] = None,
         recursive: bool = False) -> ArrayContainerT:
     """Helper for :func:`rec_multimap_array_container`.
@@ -124,9 +128,9 @@ def _multimap_array_container_impl(
 
                 new_args[i] = subary
 
-            result.append((key, frec(*new_args)))       # type: ignore
+            result.append((key, frec(*new_args)))       # type: ignore[operator]
 
-        return deserialize_container(template_ary, result)
+        return process_container(template_ary, result)     # type: ignore[operator]
 
     container_indices: List[int] = [
             i for i, arg in enumerate(args)
@@ -135,7 +139,7 @@ def _multimap_array_container_impl(
     if not container_indices:
         return f(*args)
 
-    if len(container_indices) == 1:
+    if len(container_indices) == 1 and reduce_func is None:
         # NOTE: if we just have one ArrayContainer in args, passing it through
         # _map_array_container_impl should be faster
         def wrapper(ary: ArrayContainerT) -> ArrayContainerT:
@@ -149,9 +153,15 @@ def _multimap_array_container_impl(
                 wrapper, template_ary,
                 leaf_cls=leaf_cls, recursive=recursive)
 
+    process_container = deserialize_container if reduce_func is None else reduce_func
     frec = rec if recursive else f
+
     return rec(*args)
 
+# }}}
+
+
+# {{{ array container traversal
 
 def map_array_container(
         f: Callable[[Any], Any],
@@ -233,6 +243,10 @@ def multimapped_over_array_containers(
     update_wrapper(wrapper, f)
     return wrapper
 
+# }}}
+
+
+# {{{ keyed array container traversal
 
 def keyed_map_array_container(f: Callable[[Any, Any], Any],
                               ary: ArrayContainerT) -> ArrayContainerT:
@@ -266,9 +280,8 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
     def rec(keys: Tuple[Union[str, int], ...],
             _ary: ArrayContainerT) -> ArrayContainerT:
         if is_array_container(_ary):
-
             return deserialize_container(_ary, [
-                    (key, rec(keys+(key,), subary))
+                    (key, rec(keys + (key,), subary))
                     for key, subary in serialize_container(_ary)
                     ])
         else:
@@ -279,6 +292,57 @@ def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
 # }}}
 
 
+# {{{ array container reductions
+
+def rec_map_reduce_array_container(
+        reduce_func: Callable[[Iterable[Any]], Any],
+        map_func: Callable[[Any], Any],
+        ary: ArrayContainerT) -> Any:
+    """Perform a map-reduce over array containers recursively.
+
+    :param reduce_func: callable used to reduce over the components of the
+        :class:`~arraycontext.ArrayContainer`.
+    :param map_func: callable used to map a single component of the
+        :class:`~arraycontext.ArrayContainer`. The callable takes arrays of
+        type :class:`arraycontext.ArrayContext.array_types` and returns an
+        array of the same type or a scalar.
+    """
+    def rec(_ary: ArrayContainerT) -> ArrayContainerT:
+        if is_array_container(_ary):
+            return reduce_func([
+                rec(subary) for _, subary in serialize_container(_ary)
+                ])
+        else:
+            return map_func(_ary)
+
+    return rec(ary)
+
+
+def rec_multimap_reduce_array_container(
+        reduce_func: Callable[[Iterable[Any]], Any],
+        map_func: Callable[..., Any],
+        *args: Any) -> Any:
+    """Perform a map-reduce over multiple array containers recursively.
+
+    :param reduce_func: callable used to reduce over the components of the
+        :class:`~arraycontext.ArrayContainer`.
+    :param map_func: callable used to map a single component of the
+        :class:`~arraycontext.ArrayContainer`. The callable takes arrays of
+        type :class:`arraycontext.ArrayContext.array_types` and returns an
+        array of the same type or a scalar.
+    """
+    # NOTE: this wrapper matches the signature of `deserialize_container`
+    # to make plugging into `_multimap_array_container_impl` easier
+    def _reduce_wrapper(ary: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any:
+        return reduce_func([subary for _, subary in iterable])
+
+    return _multimap_array_container_impl(
+        map_func, *args,
+        reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True)
+
+# }}}
+
+
 # {{{ freeze/thaw
 
 @singledispatch
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index bd9eb08..01f80f3 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -26,13 +26,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from functools import partial
+from functools import partial, reduce
 import operator
 
 from arraycontext.fake_numpy import \
         BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
-from arraycontext.container.traversal import (rec_multimap_array_container,
-                                              rec_map_array_container)
+from arraycontext.container.traversal import (
+        rec_multimap_array_container, rec_map_array_container,
+        rec_map_reduce_array_container,
+        )
 
 try:
     import pyopencl as cl  # noqa: F401
@@ -104,24 +106,35 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_multimap_array_container(where_inner, criterion, then, else_)
 
     def sum(self, a, dtype=None):
-        result = cl_array.sum(a, dtype=dtype, queue=self._array_context.queue)
+        result = rec_map_reduce_array_container(
+                sum,
+                partial(cl_array.sum, dtype=dtype, queue=self._array_context.queue),
+                a)
+
         if not self._array_context._force_device_scalars:
             result = result.get()[()]
-
         return result
 
     def min(self, a):
-        result = cl_array.min(a, queue=self._array_context.queue)
+        queue = self._array_context.queue
+        result = rec_map_reduce_array_container(
+                partial(reduce, partial(cl_array.minimum, queue=queue)),
+                partial(cl_array.min, queue=queue),
+                a)
+
         if not self._array_context._force_device_scalars:
             result = result.get()[()]
-
         return result
 
     def max(self, a):
-        result = cl_array.max(a, queue=self._array_context.queue)
+        queue = self._array_context.queue
+        result = rec_map_reduce_array_container(
+                partial(reduce, partial(cl_array.maximum, queue=queue)),
+                partial(cl_array.max, queue=queue),
+                a)
+
         if not self._array_context._force_device_scalars:
             result = result.get()[()]
-
         return result
 
     def stack(self, arrays, axis=0):
@@ -163,25 +176,15 @@ class PyOpenCLFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_map_array_container(_rec_ravel, a)
 
     def vdot(self, x, y, dtype=None):
-        import pyopencl.array as cl_array
-        from arraycontext import is_array_container, serialize_container
-
-        def _rec_vdot(xi, yi):
-            if is_array_container(xi):
-                assert type(xi) == type(yi)
-                return sum(_rec_vdot(subxi, subyi)
-                    for (_, subxi), (_, subyi) in zip(
-                        serialize_container(xi), serialize_container(yi)
-                    ))
-            else:
-                result = cl_array.vdot(xi, yi,
-                    dtype=dtype, queue=self._array_context.queue)
-                if not self._array_context._force_device_scalars:
-                    result = result.get()[()]
-
-                return result
+        from arraycontext import rec_multimap_reduce_array_container
+        result = rec_multimap_reduce_array_container(
+                sum,
+                partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
+                x, y)
 
-        return _rec_vdot(x, y)
+        if not self._array_context._force_device_scalars:
+            result = result.get()[()]
+        return result
 
 # }}}
 
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index 9793dde..f17a4ab 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -21,11 +21,15 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
-
-from arraycontext.fake_numpy import \
-        BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
-from arraycontext.container.traversal import \
-        rec_multimap_array_container, rec_map_array_container
+from functools import partial, reduce
+
+from arraycontext.fake_numpy import (
+        BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
+        )
+from arraycontext.container.traversal import (
+        rec_multimap_array_container, rec_map_array_container,
+        rec_map_reduce_array_container,
+        )
 import pytato as pt
 
 
@@ -82,20 +86,26 @@ class PytatoFakeNumpyNamespace(BaseFakeNumpyNamespace):
         return rec_multimap_array_container(pt.where, criterion, then, else_)
 
     def sum(self, a, dtype=None):
-        if dtype not in [a.dtype, None]:
-            raise NotImplementedError
-        return pt.sum(a)
+        def _pt_sum(ary):
+            if dtype not in [ary.dtype, None]:
+                raise NotImplementedError
+
+            return pt.sum(ary)
+
+        return rec_map_reduce_array_container(sum, _pt_sum, a)
 
     def min(self, a):
-        return pt.amin(a)
+        return rec_map_reduce_array_container(
+                partial(reduce, pt.minimum), pt.amin, a)
 
     def max(self, a):
-        return pt.amax(a)
+        return rec_map_reduce_array_container(
+                partial(reduce, pt.maximum), pt.amax, a)
 
     def stack(self, arrays, axis=0):
-        return rec_multimap_array_container(lambda *args: pt.stack(arrays=args,
-                                                                   axis=axis),
-                                            *arrays)
+        return rec_multimap_array_container(
+                lambda *args: pt.stack(arrays=args, axis=axis),
+                *arrays)
 
     # {{{ relational operators
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index c9c26e2..5bb9fb7 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -253,19 +253,27 @@ def assert_close_to_numpy_in_containers(actx, op, args):
 # {{{ np.function same as numpy
 
 @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [
-            ("sin", 1, np.float64),
-            ("sin", 1, np.complex128),
-            ("exp", 1, np.float64),
+            # float only
             ("arctan2", 2, np.float64),
             ("minimum", 2, np.float64),
             ("maximum", 2, np.float64),
             ("where", 3, np.float64),
+            ("min", 1, np.float64),
+            ("max", 1, np.float64),
+
+            # float + complex
+            ("sin", 1, np.float64),
+            ("sin", 1, np.complex128),
+            ("exp", 1, np.float64),
+            ("exp", 1, np.complex128),
             ("conj", 1, np.float64),
             ("conj", 1, np.complex128),
             ("vdot", 2, np.float64),
             ("vdot", 2, np.complex128),
             ("abs", 1, np.float64),
             ("abs", 1, np.complex128),
+            ("sum", 1, np.float64),
+            ("sum", 1, np.complex64),
             ])
 def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
     actx = actx_factory()
@@ -494,7 +502,7 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory):
 # {{{ reductions same as numpy
 
 @pytest.mark.parametrize("op", ["sum", "min", "max"])
-def test_dof_array_reductions_same_as_numpy(actx_factory, op):
+def test_reductions_same_as_numpy(actx_factory, op):
     actx = actx_factory()
 
     ary = np.random.randn(3000)
-- 
GitLab