diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0f274a75eec4e02c20b5644e83980ebbbfd0ac64..f66cd6f3288e455cc73a98d3374fd43449a5c0d7 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -492,9 +492,10 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec): # }}} -# {{{ test array container +# {{{ array container classes for test -@with_container_arithmetic(bcast_obj_array=False, rel_comparison=True) +@with_container_arithmetic(bcast_obj_array=False, + eq_comparison=False, rel_comparison=False) @dataclass_array_container @dataclass(frozen=True) class MyContainer: @@ -765,6 +766,23 @@ def test_norm_ord_none(actx_factory, ndim): np.testing.assert_allclose(actx.to_numpy(norm_a), norm_a_ref) +def test_container_equality(actx_factory): + actx = actx_factory() + + ary_dof, _, _, dc_of_dofs, bcast_dc_of_dofs = \ + _get_test_containers(actx) + _, _, _, dc_of_dofs_2, bcast_dc_of_dofs_2 = \ + _get_test_containers(actx) + + # MyContainer sets eq_comparison to False, so equality comparison should + # not succeed. + dc = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) + dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) + assert dc != dc2 + + assert isinstance(bcast_dc_of_dofs == bcast_dc_of_dofs_2, MyContainerDOFBcast) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: