diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index b19aa00238e450269204fb6a0512e648f44d64c0..02f6692ec453632dd998e834e4dd305722e560b6 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -30,10 +30,14 @@ THE SOFTWARE. """ +import numpy as np + + # {{{ with_container_arithmetic class _OpClass(enum.Enum): ARITHMETIC = enum.auto + MATMUL = enum.auto BITWISE = enum.auto SHIFT = enum.auto EQ_COMPARISON = enum.auto @@ -56,6 +60,8 @@ _BINARY_OP_AND_DUNDER = [ ("mod", "{} % {}", True, _OpClass.ARITHMETIC), ("divmod", "divmod({}, {})", True, _OpClass.ARITHMETIC), + ("matmul", "{} @ {}", True, _OpClass.MATMUL), + ("and", "{} & {}", True, _OpClass.BITWISE), ("or", "{} | {}", True, _OpClass.BITWISE), ("xor", "{} ^ {}", True, _OpClass.BITWISE), @@ -109,9 +115,10 @@ def _format_binary_op_str(op_str, arg1, arg2): return op_str.format(arg1, arg2) -def with_container_arithmetic( +def with_container_arithmetic(*, bcast_number=True, bcast_obj_array=None, bcast_numpy_array=False, - arithmetic=True, bitwise=False, shift=False, + bcast_container_types=None, + arithmetic=True, matmul=False, bitwise=False, shift=False, eq_comparison=None, rel_comparison=None): """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. @@ -122,6 +129,13 @@ def with_container_arithmetic( the container. (with the container as the 'inner' structure) :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast over the container. (with the container as the 'inner' structure) + If this is set to *True*, *bcast_obj_array* must also be *True*. + :arg bcast_container_types: A sequence of container types that will broadcast + over this container (with this container as the 'outer' structure). + :class:`numpy.ndarray` is permitted to be part of this sequence to + indicate that, in such broadcasting situations, this container should + be the 'outer' structure. In this case, *bcast_obj_array* + (and consequently *bcast_numpy_array*) must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -183,9 +197,18 @@ def with_container_arithmetic( def numpy_pred(name): return "False" # optimized away + if bcast_container_types is None: + bcast_container_types = () + + if np.ndarray in bcast_container_types and bcast_obj_array: + raise ValueError("If numpy.ndarray is part of bcast_container_types, " + "bcast_obj_array must be False.") + desired_op_classes = set() if arithmetic: desired_op_classes.add(_OpClass.ARITHMETIC) + if matmul: + desired_op_classes.add(_OpClass.MATMUL) if bitwise: desired_op_classes.add(_OpClass.BITWISE) if shift: @@ -215,10 +238,25 @@ def with_container_arithmetic( """) gen("") + if bcast_container_types: + for i, bct in enumerate(bcast_container_types): + gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") + gen("") + outer_bcast_type_names = [ + f"_bctype{i}" for i in range(len(bcast_container_types))] + if bcast_number: + outer_bcast_type_names.append("Number") + def same_key(k1, k2): assert k1 == k2 return k1 + def tup_str(t): + if not t: + return "()" + else: + return "(%s,)" % ", ".join(t) + # {{{ unary operators for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER: @@ -266,10 +304,10 @@ def with_container_arithmetic( def {fname}(arg1, arg2): if arg2.__class__ is cls: return cls({zip_init_args}) - if {bcast_number}: # optimized away - if isinstance(arg2, Number): + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, {tup_str(outer_bcast_type_names)}): return cls({bcast_init_args}) - if {numpy_pred("arg2")}: + if {numpy_pred("arg2")}: # optimized away result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} @@ -294,10 +332,10 @@ def with_container_arithmetic( def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bcast_number}: # optimized away - if isinstance(arg1, Number): + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, {tup_str(outer_bcast_type_names)}): return cls({bcast_init_args}) - if {numpy_pred("arg1")}: + if {numpy_pred("arg1")}: # optimized away result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index f1f78a16b1b7e35bb2570f86cced027ab7b5d286..57bc7761d10b7c8ab802a9817a71b1bdb0863514 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -461,6 +461,24 @@ class MyContainer: return self.mass.array_context +@with_container_arithmetic( + bcast_obj_array=False, + bcast_container_types=(DOFArray, np.ndarray), + matmul=True, + rel_comparison=True,) +@dataclass_array_container +@dataclass(frozen=True) +class MyContainerDOFBcast: + name: str + mass: DOFArray + momentum: np.ndarray + enthalpy: DOFArray + + @property + def array_context(self): + return self.mass.array_context + + def _get_test_containers(actx, ambient_dim=2): x = DOFArray(actx, (actx.from_numpy(np.random.randn(50_000)),)) @@ -471,18 +489,27 @@ def _get_test_containers(actx, ambient_dim=2): momentum=make_obj_array([x, x]), enthalpy=x) + # pylint: disable=unexpected-keyword-arg, no-value-for-parameter + bcast_dataclass_of_dofs = MyContainerDOFBcast( + name="container", + mass=x, + momentum=make_obj_array([x, x]), + enthalpy=x) + ary_dof = x ary_of_dofs = make_obj_array([x, x, x]) - mat_of_dofs = np.empty((2, 2), dtype=object) + mat_of_dofs = np.empty((3, 3), dtype=object) for i in np.ndindex(mat_of_dofs.shape): mat_of_dofs[i] = x - return ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs + return (ary_dof, ary_of_dofs, mat_of_dofs, dataclass_of_dofs, + bcast_dataclass_of_dofs) def test_container_multimap(actx_factory): actx = actx_factory() - ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx) + ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) # {{{ check @@ -522,7 +549,8 @@ def test_container_multimap(actx_factory): def test_container_arithmetic(actx_factory): actx = actx_factory() - ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx) + ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) # {{{ check @@ -542,12 +570,38 @@ def test_container_arithmetic(actx_factory): with pytest.raises(TypeError): ary_of_dofs + dc_of_dofs + with pytest.raises(TypeError): + dc_of_dofs + ary_of_dofs + + with pytest.raises(TypeError): + ary_dof + dc_of_dofs + + with pytest.raises(TypeError): + dc_of_dofs + ary_dof + + bcast_result = ary_dof + bcast_dc_of_dofs + bcast_dc_of_dofs + ary_dof + + assert actx.np.linalg.norm(bcast_result.mass - 2*ary_of_dofs) < 1e-8 + + mock_gradient = MyContainerDOFBcast( + name="yo", + mass=ary_of_dofs, + momentum=mat_of_dofs, + enthalpy=ary_of_dofs) + + grad_matvec_result = mock_gradient @ ary_of_dofs + assert isinstance(grad_matvec_result.mass, DOFArray) + assert grad_matvec_result.momentum.shape == (3,) + assert actx.np.linalg.norm(grad_matvec_result.mass - 3*ary_of_dofs**2) < 1e-8 + # }}} def test_container_freeze_thaw(actx_factory): actx = actx_factory() - ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx) + ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) # {{{ check @@ -577,7 +631,8 @@ def test_container_freeze_thaw(actx_factory): def test_container_norm(actx_factory, ord): actx = actx_factory() - ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs = _get_test_containers(actx) + ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) from pytools.obj_array import make_obj_array c = MyContainer(name="hey", mass=1, momentum=make_obj_array([2, 3]), enthalpy=5)