diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 5b8146858119e233338b8e23d4414f1005e3cdda..d78070b053c7aa3460da19a65257a658b77238bb 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -334,10 +334,15 @@ class DtypedArgument(Argument): self.name, self.dtype) + def __eq__(self, other): + return (type(self) == type(other) + and self.dtype == other.dtype + and self.name == other.name) + class VectorArg(DtypedArgument): def __init__(self, dtype, name, with_offset=False): - DtypedArgument.__init__(self, dtype, name) + super().__init__(dtype, name) self.with_offset = with_offset def declarator(self): @@ -350,6 +355,10 @@ class VectorArg(DtypedArgument): return result + def __eq__(self, other): + return (super().__eq__(other) + and self.with_offset == other.with_offset) + class ScalarArg(DtypedArgument): def declarator(self): @@ -1025,6 +1034,11 @@ def is_spirv(s): # {{{ numpy key types builder class _NumpyTypesKeyBuilder(KeyBuilderBase): + def update_for_VectorArg(self, key_hash, key): # noqa: N802 + self.rec(key_hash, key.dtype) + self.update_for_str(key_hash, key.name) + self.rec(key_hash, key.with_offset) + def update_for_type(self, key_hash, key): if issubclass(key, np.generic): self.update_for_str(key_hash, key.__name__)