From 89c38431d343f5b1a92dab4b5689c14159f5a865 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 20 Jan 2021 11:54:11 -0600 Subject: [PATCH] Make VectorArg, ScalarArg comparable and add persistent-dict key generation for them --- pyopencl/tools.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 5b814685..d78070b0 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -334,10 +334,15 @@ class DtypedArgument(Argument): self.name, self.dtype) + def __eq__(self, other): + return (type(self) == type(other) + and self.dtype == other.dtype + and self.name == other.name) + class VectorArg(DtypedArgument): def __init__(self, dtype, name, with_offset=False): - DtypedArgument.__init__(self, dtype, name) + super().__init__(dtype, name) self.with_offset = with_offset def declarator(self): @@ -350,6 +355,10 @@ class VectorArg(DtypedArgument): return result + def __eq__(self, other): + return (super().__eq__(other) + and self.with_offset == other.with_offset) + class ScalarArg(DtypedArgument): def declarator(self): @@ -1025,6 +1034,11 @@ def is_spirv(s): # {{{ numpy key types builder class _NumpyTypesKeyBuilder(KeyBuilderBase): + def update_for_VectorArg(self, key_hash, key): # noqa: N802 + self.rec(key_hash, key.dtype) + self.update_for_str(key_hash, key.name) + self.rec(key_hash, key.with_offset) + def update_for_type(self, key_hash, key): if issubclass(key, np.generic): self.update_for_str(key_hash, key.__name__) -- GitLab