From be0f569e805f65942ec4c4ad04e6982c0fe42511 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com>
Date: Tue, 6 Jul 2021 09:40:10 -0500
Subject: [PATCH] Broadcast array arg in binary ops if it's a valid leaf array
 type (#51)

* broadcast arithmetic if arg2 is actx's array

* define and implement get_array_types

Co-authored-by: Alex Fikl <alexfikl@gmail.com>

* adds test_leaf_array_type_broadcasting

* avoid using iff in the docs

replaced with only if

Co-authored-by: Alex Fikl <alexfikl@gmail.com>

* adds docs for test, tests with scalars

* ArrayContext.get_array_types() -> ArrayContext.array_types

* bcast_actx_array_type -> _bcast_actx_array_type

Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>

* formatting: remove unused import

* leaf array bcast types: better code placement

* make array_types a property

* docs: grammar

* better var naming: ary_types -> bcast_actx_ary_types

* Revert "make array_types a property"

This reverts commit e6b8b1b1860aadd6cd95269ca18dc22303b5b8b5.

* ArrayContext: make array_types a class attribute

Co-authored-by: Alex Fikl <alexfikl@gmail.com>
Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>
Co-authored-by: Andreas Kloeckner <inform@tiker.net>
---
 arraycontext/container/arithmetic.py   | 42 +++++++++++++++++++---
 arraycontext/context.py                |  9 ++++-
 arraycontext/impl/pyopencl/__init__.py |  3 ++
 arraycontext/impl/pytato/__init__.py   |  2 ++
 test/test_arraycontext.py              | 49 ++++++++++++++++++++++++++
 5 files changed, 100 insertions(+), 5 deletions(-)

diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 9a63670..7d5daa1 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -127,6 +127,7 @@ def _format_binary_op_str(op_str: str,
 def with_container_arithmetic(
         *,
         bcast_number: bool = True,
+        _bcast_actx_array_type: Optional[bool] = None,
         bcast_obj_array: Optional[bool] = None,
         bcast_numpy_array: bool = False,
         bcast_container_types: Optional[Tuple[type, ...]] = None,
@@ -142,6 +143,11 @@ def with_container_arithmetic(
 
     :arg bcast_number: If *True*, numbers broadcast over the container
         (with the container as the 'outer' structure).
+    :arg _bcast_actx_array_type: If *True*, instances of base array types of the
+        container's array context are broadcasted over the container. Can be
+        *True* only if the container has *_cls_has_array_context_attr* set.
+        Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set,
+        else *False*.
     :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over
         the container.  (with the container as the 'inner' structure)
     :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
@@ -209,6 +215,16 @@ def with_container_arithmetic(
     if not bcast_obj_array and bcast_numpy_array:
         raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")
 
+    if _bcast_actx_array_type is None:
+        if _cls_has_array_context_attr:
+            _bcast_actx_array_type = bcast_number
+        else:
+            _bcast_actx_array_type = False
+    else:
+        if _bcast_actx_array_type and not _cls_has_array_context_attr:
+            raise TypeError("_bcast_actx_array_type can be True only if "
+                            "_cls_has_array_context_attr is set.")
+
     if bcast_numpy_array:
         def numpy_pred(name: str) -> str:
             return f"isinstance({name}, np.ndarray)"
@@ -331,7 +347,7 @@ def with_container_arithmetic(
                         cls._serialize_init_arrays_code("arg1").items(),
                         cls._serialize_init_arrays_code("arg2").items())
                     })
-            bcast_init_args = cls._deserialize_init_arrays_code("arg1", {
+            bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", {
                     key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
                     for key_arg1, expr_arg1 in
                     cls._serialize_init_arrays_code("arg1").items()
@@ -357,10 +373,19 @@ def with_container_arithmetic(
                                 else:
                                     raise ValueError(msg)""")
                     gen(f"return cls({zip_init_args})")
+
+                if _bcast_actx_array_type:
+                    bcast_actx_ary_types: Tuple[str, ...] = (
+                        "*arg1.array_context.array_types",)
+                else:
+                    bcast_actx_ary_types = ()
+
                 gen(f"""
                 if {bool(outer_bcast_type_names)}:  # optimized away
-                    if isinstance(arg2, {tup_str(outer_bcast_type_names)}):
-                        return cls({bcast_init_args})
+                    if isinstance(arg2,
+                                  {tup_str(outer_bcast_type_names
+                                           + bcast_actx_ary_types)}):
+                        return cls({bcast_same_cls_init_args})
                 if {numpy_pred("arg2")}:
                     result = np.empty_like(arg2, dtype=object)
                     for i in np.ndindex(arg2.shape):
@@ -383,12 +408,20 @@ def with_container_arithmetic(
                         for key_arg2, expr_arg2 in
                         cls._serialize_init_arrays_code("arg2").items()
                         })
+
+                if _bcast_actx_array_type:
+                    bcast_actx_ary_types = ("*arg2.array_context.array_types",)
+                else:
+                    bcast_actx_ary_types = ()
+
                 gen(f"""
                     def {fname}(arg2, arg1):
                         # assert other.__cls__ is not cls
 
                         if {bool(outer_bcast_type_names)}:  # optimized away
-                            if isinstance(arg1, {tup_str(outer_bcast_type_names)}):
+                            if isinstance(arg1,
+                                          {tup_str(outer_bcast_type_names
+                                                   + bcast_actx_ary_types)}):
                                 return cls({bcast_init_args})
                         if {numpy_pred("arg1")}:
                             result = np.empty_like(arg1, dtype=object)
@@ -406,6 +439,7 @@ def with_container_arithmetic(
 
         # This will evaluate the module, which is all we need.
         code = gen.get().rstrip()+"\n"
+
         result_dict = {"_MODULE_SOURCE_CODE": code, "cls": cls}
         exec(compile(code, f"<container arithmetic for {cls.__name__}>", "exec"),
                 result_dict)
diff --git a/arraycontext/context.py b/arraycontext/context.py
index 13ce197..3d2df31 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -102,7 +102,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Sequence, Union, Callable, Any
+from typing import Sequence, Union, Callable, Any, Tuple
 from abc import ABC, abstractmethod, abstractproperty
 
 import numpy as np
@@ -144,6 +144,11 @@ class ArrayContext(ABC):
          Callables accessible through this namespace vectorize over object
          arrays, including :class:`arraycontext.ArrayContainer`\ s.
 
+    .. attribute:: array_types
+
+        A :class:`tuple` of types that are the valid base array classes
+        the context can operate on.
+
     .. automethod:: freeze
     .. automethod:: thaw
     .. automethod:: tag
@@ -151,6 +156,8 @@ class ArrayContext(ABC):
     .. automethod:: compile
     """
 
+    array_types: Tuple[type, ...] = ()
+
     def __init__(self):
         self.np = self._get_fake_numpy_namespace()
 
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 0431902..98f3ffc 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -109,6 +109,7 @@ class PyOpenCLArrayContext(ArrayContext):
                     DeprecationWarning, stacklevel=2)
 
         import pyopencl as cl
+        import pyopencl.array as cla
 
         super().__init__()
         self.context = queue.context
@@ -137,6 +138,8 @@ class PyOpenCLArrayContext(ArrayContext):
         self._loopy_transform_cache: \
                 Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {}
 
+        self.array_types = (cla.Array,)
+
     def _get_fake_numpy_namespace(self):
         from arraycontext.impl.pyopencl.fake_numpy import PyOpenCLFakeNumpyNamespace
         return PyOpenCLFakeNumpyNamespace(self)
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index b6f035e..2b262eb 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -65,9 +65,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     """
 
     def __init__(self, queue, allocator=None):
+        import pytato as pt
         super().__init__()
         self.queue = queue
         self.allocator = allocator
+        self.array_types = (pt.Array, )
 
         # unused, but necessary to keep the context alive
         self.context = self.queue.context
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index e1a2c08..a653655 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -877,6 +877,55 @@ def test_abs_complex(actx_factory):
     np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref)
 
 
+@with_container_arithmetic(
+    bcast_obj_array=True,
+    bcast_numpy_array=True,
+    rel_comparison=True,
+    _cls_has_array_context_attr=True)
+@dataclass_array_container
+@dataclass(frozen=True)
+class Foo:
+    u: DOFArray
+
+    @property
+    def array_context(self):
+        return self.u.array_context
+
+
+def test_leaf_array_type_broadcasting(actx_factory):
+    # test support for https://github.com/inducer/arraycontext/issues/49
+    actx = actx_factory()
+
+    foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, )))
+    bar = foo + 4
+    baz = foo + actx.from_numpy(4*np.ones((3, )))
+    qux = actx.from_numpy(4*np.ones((3, ))) + foo
+
+    np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
+                               actx.to_numpy(baz.u[0]))
+
+    np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
+                               actx.to_numpy(qux.u[0]))
+
+    def _actx_allows_scalar_broadcast(actx):
+        if not isinstance(actx, PyOpenCLArrayContext):
+            return True
+        else:
+            import pyopencl as cl
+            # See https://github.com/inducer/pyopencl/issues/498
+            return cl.version.VERSION > (2021, 2, 5)
+
+    if _actx_allows_scalar_broadcast(actx):
+        quux = foo + actx.from_numpy(np.array(4))
+        quuz = actx.from_numpy(np.array(4)) + foo
+
+        np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
+                                   actx.to_numpy(quux.u[0]))
+
+        np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
+                                   actx.to_numpy(quuz.u[0]))
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab