diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 7e5da8b7d557c4ffb64d616aaf88c477ab352a5c..9a6367009e3948444fb72e6282497443fd48ef89 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -41,13 +41,14 @@ import numpy as np T = TypeVar("T") +@enum.unique class _OpClass(enum.Enum): - ARITHMETIC = enum.auto - MATMUL = enum.auto - BITWISE = enum.auto - SHIFT = enum.auto - EQ_COMPARISON = enum.auto - REL_COMPARISON = enum.auto + ARITHMETIC = enum.auto() + MATMUL = enum.auto() + BITWISE = enum.auto() + SHIFT = enum.auto() + EQ_COMPARISON = enum.auto() + REL_COMPARISON = enum.auto() _UNARY_OP_AND_DUNDER = [ @@ -304,12 +305,25 @@ def with_container_arithmetic( # {{{ binary operators for dunder_name, op_str, reversible, op_cls in _BINARY_OP_AND_DUNDER: + fname = f"_{cls.__name__.lower()}_{dunder_name}" + if op_cls not in desired_op_classes: + # Leaving equality comparison at the default supplied by + # dataclasses is dangerous: Comparison of dataclass fields + # might return an array of truth values, and the dataclasses + # implementation of __eq__ might consider that 'truthy' enough, + # yielding bogus equality results. + if op_cls == _OpClass.EQ_COMPARISON: + gen(f"def {fname}(arg1, arg2):") + with Indentation(gen): + gen("return NotImplemented") + gen(f"cls.__{dunder_name}__ = {fname}") + gen("") + continue # {{{ "forward" binary operators - fname = f"_{cls.__name__.lower()}_{dunder_name}" zip_init_args = cls._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 5425dd412f69ac491a27e400dfd3b8b75f40be79..094ee1346964b9dac1d6a46cabba0f7b071dae76 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -90,6 +90,7 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts([ @with_container_arithmetic( bcast_obj_array=True, bcast_numpy_array=True, + bitwise=True, rel_comparison=True, _cls_has_array_context_attr=True) class DOFArray: @@ -519,9 +520,10 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec): # }}} -# {{{ test array container +# {{{ array container classes for test -@with_container_arithmetic(bcast_obj_array=False, rel_comparison=True) +@with_container_arithmetic(bcast_obj_array=False, + eq_comparison=False, rel_comparison=False) @dataclass_array_container @dataclass(frozen=True) class MyContainer: @@ -833,6 +835,23 @@ def test_actx_compile(actx_factory): np.testing.assert_allclose(result.v, 3.14*v_x) +def test_container_equality(actx_factory): + actx = actx_factory() + + ary_dof, _, _, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) + _, _, _, dc_of_dofs_2, bcast_dc_of_dofs_2 = \ + _get_test_containers(actx) + + # MyContainer sets eq_comparison to False, so equality comparison should + # not succeed. + dc = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) + dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) + assert dc != dc2 + + assert isinstance(bcast_dc_of_dofs == bcast_dc_of_dofs_2, MyContainerDOFBcast) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: