diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index f3fb4af79a98fa732b1a2f4a784cd2952377cdef..52c87d76800fb96b7070da68cf8a065a2d3f7b9a 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -304,12 +304,25 @@ def with_container_arithmetic( # {{{ binary operators for dunder_name, op_str, reversible, op_cls in _BINARY_OP_AND_DUNDER: + fname = f"_{cls.__name__.lower()}_{dunder_name}" + if op_cls not in desired_op_classes: + # Leaving equality comparison at the default supplied by + # dataclasses is dangerous: Comparison of dataclass fields + # might return an array of truth values, and the dataclasses + # implementation of __eq__ might consider that 'truthy' enough, + # yielding bogus equality results. + if op_cls == _OpClass.EQ_COMPARISON: + gen(f"def {fname}(arg1, arg2):") + with Indentation(gen): + gen("return NotImplemented") + gen(f"cls.__{dunder_name}__ = {fname}") + gen("") + continue # {{{ "forward" binary operators - fname = f"_{cls.__name__.lower()}_{dunder_name}" zip_init_args = cls._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2)