From be0f569e805f65942ec4c4ad04e6982c0fe42511 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> Date: Tue, 6 Jul 2021 09:40:10 -0500 Subject: [PATCH] Broadcast array arg in binary ops if it's a valid leaf array type (#51) * broadcast arithmetic if arg2 is actx's array * define and implement get_array_types Co-authored-by: Alex Fikl <alexfikl@gmail.com> * adds test_leaf_array_type_broadcasting * avoid using iff in the docs replaced with only if Co-authored-by: Alex Fikl <alexfikl@gmail.com> * adds docs for test, tests with scalars * ArrayContext.get_array_types() -> ArrayContext.array_types * bcast_actx_array_type -> _bcast_actx_array_type Co-authored-by: Andreas Kloeckner <andreask@illinois.edu> * formatting: remove unused import * leaf array bcast types: better code placement * make array_types a property * docs: grammar * better var naming: ary_types -> bcast_actx_ary_types * Revert "make array_types a property" This reverts commit e6b8b1b1860aadd6cd95269ca18dc22303b5b8b5. * ArrayContext: make array_types a class attribute Co-authored-by: Alex Fikl <alexfikl@gmail.com> Co-authored-by: Andreas Kloeckner <andreask@illinois.edu> Co-authored-by: Andreas Kloeckner <inform@tiker.net> --- arraycontext/container/arithmetic.py | 42 +++++++++++++++++++--- arraycontext/context.py | 9 ++++- arraycontext/impl/pyopencl/__init__.py | 3 ++ arraycontext/impl/pytato/__init__.py | 2 ++ test/test_arraycontext.py | 49 ++++++++++++++++++++++++++ 5 files changed, 100 insertions(+), 5 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 9a63670..7d5daa1 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -127,6 +127,7 @@ def _format_binary_op_str(op_str: str, def with_container_arithmetic( *, bcast_number: bool = True, + _bcast_actx_array_type: Optional[bool] = None, bcast_obj_array: Optional[bool] = None, bcast_numpy_array: bool = False, bcast_container_types: Optional[Tuple[type, ...]] = None, @@ -142,6 +143,11 @@ def with_container_arithmetic( :arg bcast_number: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). + :arg _bcast_actx_array_type: If *True*, instances of base array types of the + container's array context are broadcasted over the container. Can be + *True* only if the container has *_cls_has_array_context_attr* set. + Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, + else *False*. :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over the container. (with the container as the 'inner' structure) :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast @@ -209,6 +215,16 @@ def with_container_arithmetic( if not bcast_obj_array and bcast_numpy_array: raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") + if _bcast_actx_array_type is None: + if _cls_has_array_context_attr: + _bcast_actx_array_type = bcast_number + else: + _bcast_actx_array_type = False + else: + if _bcast_actx_array_type and not _cls_has_array_context_attr: + raise TypeError("_bcast_actx_array_type can be True only if " + "_cls_has_array_context_attr is set.") + if bcast_numpy_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray)" @@ -331,7 +347,7 @@ def with_container_arithmetic( cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg2").items()) }) - bcast_init_args = cls._deserialize_init_arrays_code("arg1", { + bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") for key_arg1, expr_arg1 in cls._serialize_init_arrays_code("arg1").items() @@ -357,10 +373,19 @@ def with_container_arithmetic( else: raise ValueError(msg)""") gen(f"return cls({zip_init_args})") + + if _bcast_actx_array_type: + bcast_actx_ary_types: Tuple[str, ...] = ( + "*arg1.array_context.array_types",) + else: + bcast_actx_ary_types = () + gen(f""" if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, {tup_str(outer_bcast_type_names)}): - return cls({bcast_init_args}) + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): @@ -383,12 +408,20 @@ def with_container_arithmetic( for key_arg2, expr_arg2 in cls._serialize_init_arrays_code("arg2").items() }) + + if _bcast_actx_array_type: + bcast_actx_ary_types = ("*arg2.array_context.array_types",) + else: + bcast_actx_ary_types = () + gen(f""" def {fname}(arg2, arg1): # assert other.__cls__ is not cls if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, {tup_str(outer_bcast_type_names)}): + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) @@ -406,6 +439,7 @@ def with_container_arithmetic( # This will evaluate the module, which is all we need. code = gen.get().rstrip()+"\n" + result_dict = {"_MODULE_SOURCE_CODE": code, "cls": cls} exec(compile(code, f"<container arithmetic for {cls.__name__}>", "exec"), result_dict) diff --git a/arraycontext/context.py b/arraycontext/context.py index 13ce197..3d2df31 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -102,7 +102,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Sequence, Union, Callable, Any +from typing import Sequence, Union, Callable, Any, Tuple from abc import ABC, abstractmethod, abstractproperty import numpy as np @@ -144,6 +144,11 @@ class ArrayContext(ABC): Callables accessible through this namespace vectorize over object arrays, including :class:`arraycontext.ArrayContainer`\ s. + .. attribute:: array_types + + A :class:`tuple` of types that are the valid base array classes + the context can operate on. + .. automethod:: freeze .. automethod:: thaw .. automethod:: tag @@ -151,6 +156,8 @@ class ArrayContext(ABC): .. automethod:: compile """ + array_types: Tuple[type, ...] = () + def __init__(self): self.np = self._get_fake_numpy_namespace() diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 0431902..98f3ffc 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -109,6 +109,7 @@ class PyOpenCLArrayContext(ArrayContext): DeprecationWarning, stacklevel=2) import pyopencl as cl + import pyopencl.array as cla super().__init__() self.context = queue.context @@ -137,6 +138,8 @@ class PyOpenCLArrayContext(ArrayContext): self._loopy_transform_cache: \ Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + self.array_types = (cla.Array,) + def _get_fake_numpy_namespace(self): from arraycontext.impl.pyopencl.fake_numpy import PyOpenCLFakeNumpyNamespace return PyOpenCLFakeNumpyNamespace(self) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index b6f035e..2b262eb 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -65,9 +65,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext): """ def __init__(self, queue, allocator=None): + import pytato as pt super().__init__() self.queue = queue self.allocator = allocator + self.array_types = (pt.Array, ) # unused, but necessary to keep the context alive self.context = self.queue.context diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index e1a2c08..a653655 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -877,6 +877,55 @@ def test_abs_complex(actx_factory): np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref) +@with_container_arithmetic( + bcast_obj_array=True, + bcast_numpy_array=True, + rel_comparison=True, + _cls_has_array_context_attr=True) +@dataclass_array_container +@dataclass(frozen=True) +class Foo: + u: DOFArray + + @property + def array_context(self): + return self.u.array_context + + +def test_leaf_array_type_broadcasting(actx_factory): + # test support for https://github.com/inducer/arraycontext/issues/49 + actx = actx_factory() + + foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) + bar = foo + 4 + baz = foo + actx.from_numpy(4*np.ones((3, ))) + qux = actx.from_numpy(4*np.ones((3, ))) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + def _actx_allows_scalar_broadcast(actx): + if not isinstance(actx, PyOpenCLArrayContext): + return True + else: + import pyopencl as cl + # See https://github.com/inducer/pyopencl/issues/498 + return cl.version.VERSION > (2021, 2, 5) + + if _actx_allows_scalar_broadcast(actx): + quux = foo + actx.from_numpy(np.array(4)) + quuz = actx.from_numpy(np.array(4)) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab