diff --git a/loopy/target/pyopencl.py b/loopy/target/pyopencl.py index c5e8d0a7f7a9f70b3afe46e9d04a3bf861066329..826ba2a8f09b8a19d19200ef6d936a8276cf3688 100644 --- a/loopy/target/pyopencl.py +++ b/loopy/target/pyopencl.py @@ -299,7 +299,26 @@ class PyOpenCLTarget(OpenCLTarget): self.device = device self.pyopencl_module_name = pyopencl_module_name - comparison_fields = ["device"] + # NB: Not including 'device', as that is handled specially here. + hash_fields = OpenCLTarget.hash_fields + ( + "pyopencl_module_name",) + comparison_fields = OpenCLTarget.comparison_fields + ( + "pyopencl_module_name",) + + def __eq__(self, other): + if not super(PyOpenCLTarget, self).__eq__(other): + return False + + if (self.device is None) != (other.device is None): + return False + + if self.device is not None: + assert other.device is not None + return (self.device.persistent_unique_id + == other.device.persistent_unique_id) + else: + assert other.device is None + return True def update_persistent_hash(self, key_hash, key_builder): super(PyOpenCLTarget, self).update_persistent_hash(key_hash, key_builder)