diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py index 084c37b45cc4af25689ae3e121f170382c4e8d16..4e2819a82162b26a8c53dc25b434990473ed2d2c 100644 --- a/loopy/kernel/__init__.py +++ b/loopy/kernel/__init__.py @@ -1474,6 +1474,9 @@ class LoopKernel(ImmutableRecordWithoutPickling): return hash(key_hash.digest()) def __eq__(self, other): + if self is other: + return True + if not isinstance(other, LoopKernel): return False diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index 5d4240b9ab3e1ce2ad356a93b5e21b3bbf4d499e..7e3cf913339678731fedf54d3fd5abebcb816d3a 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -808,6 +808,9 @@ class ArrayBase(ImmutableRecord): **kwargs) def __eq__(self, other): + if self is other: + return True + from loopy.symbolic import ( is_tuple_of_expressions_equal as istoee, is_expression_equal as isee) diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 94b31df12dae516d3539438b7e4ed66ed765e697..facb3c9a05683d218684d921eb72a5e9438fd1fd 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -480,6 +480,9 @@ class TemporaryVariable(ArrayBase): " scope:%s" % scope_str) def __eq__(self, other): + if self is other: + return True + return ( super(TemporaryVariable, self).__eq__(other) and self.storage_shape == other.storage_shape diff --git a/loopy/target/__init__.py b/loopy/target/__init__.py index 7e307ef8bdd4d89e24b26dbacf39733ab3350307..42eef5b504d44a93a9c28c7862e4ecb89e899242 100644 --- a/loopy/target/__init__.py +++ b/loopy/target/__init__.py @@ -60,6 +60,9 @@ class TargetBase(object): key_builder.rec(key_hash, getattr(self, field_name)) def __eq__(self, other): + if self is other: + return True + if type(self) != type(other): return False diff --git a/loopy/tools.py b/loopy/tools.py index 1ebbe5c8a4fd2b68e3bfcf5ed493384599dac2c5..d26fa2ae903d28b4a88b3d5ae246b58c9d70efe4 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -537,6 +537,9 @@ class LazilyUnpicklingListWithEqAndPersistentHashing(LazilyUnpicklingList): return self.persistent_hash_key_getter(obj) def __eq__(self, other): + if self is other: + return True + if not isinstance(other, (list, LazilyUnpicklingList)): return NotImplemented