From e5ff3c6eaeba199766598384b6b8173f37b38714 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= <inform@tiker.net>
Date: Sat, 29 May 2021 17:00:32 -0500
Subject: [PATCH] Implement, test container-over-container broadcasting (#15)

* Implement, test container-over-container broadcasting

* Tweak bcast_container_types arg doc

Co-authored-by: Alex Fikl <alexfikl@gmail.com>

* Improvements to with_container_arithmetic

- Support matmul with a separate operator class
- Join container and number broadcasting code paths
- Make args keyword-only

Co-authored-by: Alex Fikl <alexfikl@gmail.com>
---
 arraycontext/container/arithmetic.py | 54 ++++++++++++++++++----
 test/test_arraycontext.py            | 67 +++++++++++++++++++++++++---
 2 files changed, 107 insertions(+), 14 deletions(-)

diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index b19aa00..02f6692 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -30,10 +30,14 @@ THE SOFTWARE.
 """
 
 
+import numpy as np
+
+
 # {{{ with_container_arithmetic
 
 class _OpClass(enum.Enum):
     ARITHMETIC = enum.auto
+    MATMUL = enum.auto
     BITWISE = enum.auto
     SHIFT = enum.auto
     EQ_COMPARISON = enum.auto
@@ -56,6 +60,8 @@ _BINARY_OP_AND_DUNDER = [
         ("mod", "{} % {}", True, _OpClass.ARITHMETIC),
         ("divmod", "divmod({}, {})", True, _OpClass.ARITHMETIC),
 
+        ("matmul", "{} @ {}", True, _OpClass.MATMUL),
+
         ("and", "{} & {}", True, _OpClass.BITWISE),
         ("or", "{} | {}", True, _OpClass.BITWISE),
         ("xor", "{} ^ {}", True, _OpClass.BITWISE),
@@ -109,9 +115,10 @@ def _format_binary_op_str(op_str, arg1, arg2):
         return op_str.format(arg1, arg2)
 
 
-def with_container_arithmetic(
+def with_container_arithmetic(*,
         bcast_number=True, bcast_obj_array=None, bcast_numpy_array=False,
-        arithmetic=True, bitwise=False, shift=False,
+        bcast_container_types=None,
+        arithmetic=True, matmul=False, bitwise=False, shift=False,
         eq_comparison=None, rel_comparison=None):
     """A class decorator that implements built-in operators for array containers
     by propagating the operations to the elements of the container.
@@ -122,6 +129,13 @@ def with_container_arithmetic(
         the container.  (with the container as the 'inner' structure)
     :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
         over the container.  (with the container as the 'inner' structure)
+        If this is set to *True*, *bcast_obj_array* must also be *True*.
+    :arg bcast_container_types: A sequence of container types that will broadcast
+        over this container (with this container as the 'outer' structure).
+        :class:`numpy.ndarray` is permitted to be part of this sequence to
+        indicate that, in such broadcasting situations, this container should
+        be the 'outer' structure. In this case, *bcast_obj_array*
+        (and consequently *bcast_numpy_array*) must be *False*.
     :arg arithmetic: Implement the conventional arithmetic operators, including
         ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
         :func:`abs`.
@@ -183,9 +197,18 @@ def with_container_arithmetic(
         def numpy_pred(name):
             return "False"  # optimized away
 
+    if bcast_container_types is None:
+        bcast_container_types = ()
+
+    if np.ndarray in bcast_container_types and bcast_obj_array:
+        raise ValueError("If numpy.ndarray is part of bcast_container_types, "
+                "bcast_obj_array must be False.")
+
     desired_op_classes = set()
     if arithmetic:
         desired_op_classes.add(_OpClass.ARITHMETIC)
+    if matmul:
+        desired_op_classes.add(_OpClass.MATMUL)
     if bitwise:
         desired_op_classes.add(_OpClass.BITWISE)
     if shift:
@@ -215,10 +238,25 @@ def with_container_arithmetic(
             """)
         gen("")
 
+        if bcast_container_types:
+            for i, bct in enumerate(bcast_container_types):
+                gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
+            gen("")
+        outer_bcast_type_names = [
+                f"_bctype{i}" for i in range(len(bcast_container_types))]
+        if bcast_number:
+            outer_bcast_type_names.append("Number")
+
         def same_key(k1, k2):
             assert k1 == k2
             return k1
 
+        def tup_str(t):
+            if not t:
+                return "()"
+            else:
+                return "(%s,)" % ", ".join(t)
+
         # {{{ unary operators
 
         for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER:
@@ -266,10 +304,10 @@ def with_container_arithmetic(
                 def {fname}(arg1, arg2):
                     if arg2.__class__ is cls:
                         return cls({zip_init_args})
-                    if {bcast_number}:  # optimized away
-                        if isinstance(arg2, Number):
+                    if {bool(outer_bcast_type_names)}:  # optimized away
+                        if isinstance(arg2, {tup_str(outer_bcast_type_names)}):
                             return cls({bcast_init_args})
-                    if {numpy_pred("arg2")}:
+                    if {numpy_pred("arg2")}:  # optimized away
                         result = np.empty_like(arg2, dtype=object)
                         for i in np.ndindex(arg2.shape):
                             result[i] = {op_str.format("arg1", "arg2[i]")}
@@ -294,10 +332,10 @@ def with_container_arithmetic(
                     def {fname}(arg2, arg1):
                         # assert other.__cls__ is not cls
 
-                        if {bcast_number}:  # optimized away
-                            if isinstance(arg1, Number):
+                        if {bool(outer_bcast_type_names)}:  # optimized away
+                            if isinstance(arg1, {tup_str(outer_bcast_type_names)}):
                                 return cls({bcast_init_args})
-                        if {numpy_pred("arg1")}:
+                        if {numpy_pred("arg1")}:  # optimized away
                             result = np.empty_like(arg1, dtype=object)
                             for i in np.ndindex(arg1.shape):
                                 result[i] = {op_str.format("arg1[i]", "arg2")}
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index f1f78a1..57bc776 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -461,6 +461,24 @@ class MyContainer:
         return self.mass.array_context
 
 
+@with_container_arithmetic(
+        bcast_obj_array=False,
+        bcast_container_types=(DOFArray, np.ndarray),
+        matmul=True,
+        rel_comparison=True,)
+@dataclass_array_container
+@dataclass(frozen=True)
+class MyContainerDOFBcast:
+    name: str
+    mass: DOFArray
+    momentum: np.ndarray
+    enthalpy: DOFArray
+
+    @property
+    def array_context(self):
+        return self.mass.array_context
+
+
 def _get_test_containers(actx, ambient_dim=2):
     x = DOFArray(actx, (actx.from_numpy(np.random.randn(50_000)),))
 
@@ -471,18 +489,27 @@ def _get_test_containers(actx, ambient_dim=2):
             momentum=make_obj_array([x, x]),
             enthalpy=x)
 
+    # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
+    bcast_dataclass_of_dofs = MyContainerDOFBcast(
+            name="container",
+            mass=x,
+            momentum=make_obj_array([x, x]),
+            enthalpy=x)
+
     ary_dof = x
     ary_of_dofs = make_obj_array([x, x, x])
-    mat_of_dofs = np.empty((2, 2), dtype=object)
+    mat_of_dofs = np.empty((3, 3), dtype=object)
     for i in np.ndindex(mat_of_dofs.shape):
         mat_of_dofs[i] = x
 
-    return ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs
+    return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs,
+            bcast_dataclass_of_dofs)
 
 
 def test_container_multimap(actx_factory):
     actx = actx_factory()
-    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
+    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
+            _get_test_containers(actx)
 
     # {{{ check
 
@@ -522,7 +549,8 @@ def test_container_multimap(actx_factory):
 
 def test_container_arithmetic(actx_factory):
     actx = actx_factory()
-    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
+    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
+            _get_test_containers(actx)
 
     # {{{ check
 
@@ -542,12 +570,38 @@ def test_container_arithmetic(actx_factory):
     with pytest.raises(TypeError):
         ary_of_dofs + dc_of_dofs
 
+    with pytest.raises(TypeError):
+        dc_of_dofs + ary_of_dofs
+
+    with pytest.raises(TypeError):
+        ary_dof + dc_of_dofs
+
+    with pytest.raises(TypeError):
+        dc_of_dofs + ary_dof
+
+    bcast_result = ary_dof + bcast_dc_of_dofs
+    bcast_dc_of_dofs + ary_dof
+
+    assert actx.np.linalg.norm(bcast_result.mass - 2*ary_of_dofs) < 1e-8
+
+    mock_gradient = MyContainerDOFBcast(
+            name="yo",
+            mass=ary_of_dofs,
+            momentum=mat_of_dofs,
+            enthalpy=ary_of_dofs)
+
+    grad_matvec_result = mock_gradient @ ary_of_dofs
+    assert isinstance(grad_matvec_result.mass, DOFArray)
+    assert grad_matvec_result.momentum.shape == (3,)
+    assert actx.np.linalg.norm(grad_matvec_result.mass - 3*ary_of_dofs**2) < 1e-8
+
     # }}}
 
 
 def test_container_freeze_thaw(actx_factory):
     actx = actx_factory()
-    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
+    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
+            _get_test_containers(actx)
 
     # {{{ check
 
@@ -577,7 +631,8 @@ def test_container_freeze_thaw(actx_factory):
 def test_container_norm(actx_factory, ord):
     actx = actx_factory()
 
-    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx)
+    ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
+            _get_test_containers(actx)
 
     from pytools.obj_array import make_obj_array
     c = MyContainer(name="hey", mass=1, momentum=make_obj_array([2, 3]), enthalpy=5)
-- 
GitLab