diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 9a6367009e3948444fb72e6282497443fd48ef89..7d5daa14258ed8c81c12d3fe364c4890bd32ce3f 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -127,6 +127,7 @@ def _format_binary_op_str(op_str: str, def with_container_arithmetic( *, bcast_number: bool = True, + _bcast_actx_array_type: Optional[bool] = None, bcast_obj_array: Optional[bool] = None, bcast_numpy_array: bool = False, bcast_container_types: Optional[Tuple[type, ...]] = None, @@ -142,6 +143,11 @@ def with_container_arithmetic( :arg bcast_number: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). + :arg _bcast_actx_array_type: If *True*, instances of base array types of the + container's array context are broadcasted over the container. Can be + *True* only if the container has *_cls_has_array_context_attr* set. + Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, + else *False*. :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over the container. (with the container as the 'inner' structure) :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast @@ -209,6 +215,16 @@ def with_container_arithmetic( if not bcast_obj_array and bcast_numpy_array: raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") + if _bcast_actx_array_type is None: + if _cls_has_array_context_attr: + _bcast_actx_array_type = bcast_number + else: + _bcast_actx_array_type = False + else: + if _bcast_actx_array_type and not _cls_has_array_context_attr: + raise TypeError("_bcast_actx_array_type can be True only if " + "_cls_has_array_context_attr is set.") + if bcast_numpy_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray)" @@ -331,7 +347,7 @@ def with_container_arithmetic( cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg2").items()) }) - bcast_init_args = cls._deserialize_init_arrays_code("arg1", { + bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") for key_arg1, expr_arg1 in cls._serialize_init_arrays_code("arg1").items() @@ -357,10 +373,19 @@ def with_container_arithmetic( else: raise ValueError(msg)""") gen(f"return cls({zip_init_args})") + + if _bcast_actx_array_type: + bcast_actx_ary_types: Tuple[str, ...] = ( + "*arg1.array_context.array_types",) + else: + bcast_actx_ary_types = () + gen(f""" if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, {tup_str(outer_bcast_type_names)}): - return cls({bcast_init_args}) + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): @@ -383,12 +408,20 @@ def with_container_arithmetic( for key_arg2, expr_arg2 in cls._serialize_init_arrays_code("arg2").items() }) + + if _bcast_actx_array_type: + bcast_actx_ary_types = ("*arg2.array_context.array_types",) + else: + bcast_actx_ary_types = () + gen(f""" def {fname}(arg2, arg1): # assert other.__cls__ is not cls if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, {tup_str(outer_bcast_type_names)}): + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) @@ -406,6 +439,7 @@ def with_container_arithmetic( # This will evaluate the module, which is all we need. code = gen.get().rstrip()+"\n" + result_dict = {"_MODULE_SOURCE_CODE": code, "cls": cls} exec(compile(code, f"<container arithmetic for {cls.__name__}>", "exec"), result_dict) diff --git a/arraycontext/context.py b/arraycontext/context.py index 13ce197ab2f02a62c8fe0a221c618dc3f27555b2..3d2df31fa64451e5afd153d182eb7be6044f88e8 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -102,7 +102,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Sequence, Union, Callable, Any +from typing import Sequence, Union, Callable, Any, Tuple from abc import ABC, abstractmethod, abstractproperty import numpy as np @@ -144,6 +144,11 @@ class ArrayContext(ABC): Callables accessible through this namespace vectorize over object arrays, including :class:`arraycontext.ArrayContainer`\ s. + .. attribute:: array_types + + A :class:`tuple` of types that are the valid base array classes + the context can operate on. + .. automethod:: freeze .. automethod:: thaw .. automethod:: tag @@ -151,6 +156,8 @@ class ArrayContext(ABC): .. automethod:: compile """ + array_types: Tuple[type, ...] = () + def __init__(self): self.np = self._get_fake_numpy_namespace() diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 04319027d8a8d2415fdd0edb41063415356bca3c..98f3ffc8161dfe3eb57369c21e110575a4145e58 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -109,6 +109,7 @@ class PyOpenCLArrayContext(ArrayContext): DeprecationWarning, stacklevel=2) import pyopencl as cl + import pyopencl.array as cla super().__init__() self.context = queue.context @@ -137,6 +138,8 @@ class PyOpenCLArrayContext(ArrayContext): self._loopy_transform_cache: \ Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + self.array_types = (cla.Array,) + def _get_fake_numpy_namespace(self): from arraycontext.impl.pyopencl.fake_numpy import PyOpenCLFakeNumpyNamespace return PyOpenCLFakeNumpyNamespace(self) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index b6f035e4132a79d06276f5c16cc04fdfee04715c..2b262eb09d21ded0e0f9c80c2e765f17e5f5b5b1 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -65,9 +65,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext): """ def __init__(self, queue, allocator=None): + import pytato as pt super().__init__() self.queue = queue self.allocator = allocator + self.array_types = (pt.Array, ) # unused, but necessary to keep the context alive self.context = self.queue.context diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index e1a2c08b4cc63026621b6ceb7b11734015baae53..a6536550dbb1531975b3fc9ef6b7dee85581d24c 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -877,6 +877,55 @@ def test_abs_complex(actx_factory): np.testing.assert_allclose(actx.to_numpy(abs_a), abs_a_ref) +@with_container_arithmetic( + bcast_obj_array=True, + bcast_numpy_array=True, + rel_comparison=True, + _cls_has_array_context_attr=True) +@dataclass_array_container +@dataclass(frozen=True) +class Foo: + u: DOFArray + + @property + def array_context(self): + return self.u.array_context + + +def test_leaf_array_type_broadcasting(actx_factory): + # test support for https://github.com/inducer/arraycontext/issues/49 + actx = actx_factory() + + foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) + bar = foo + 4 + baz = foo + actx.from_numpy(4*np.ones((3, ))) + qux = actx.from_numpy(4*np.ones((3, ))) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + def _actx_allows_scalar_broadcast(actx): + if not isinstance(actx, PyOpenCLArrayContext): + return True + else: + import pyopencl as cl + # See https://github.com/inducer/pyopencl/issues/498 + return cl.version.VERSION > (2021, 2, 5) + + if _actx_allows_scalar_broadcast(actx): + quux = foo + actx.from_numpy(np.array(4)) + quuz = actx.from_numpy(np.array(4)) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: