diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index dd68c950e04e29fc99c456412d9dc4a53dbc61b2..975d7b3efe4bcc419a7ca004e1df3b0fbd39d5d9 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -385,6 +385,14 @@ class ArrayArg(ArrayBase, KernelArgument): + " aspace: %s" % aspace_str) + def update_persistent_hash(self, key_hash, key_builder): + """Custom hash computation function for use with + :class:`pytools.persistent_dict.PersistentDict`. + """ + super(ArrayArg, self).update_persistent_hash(key_hash, key_builder) + key_builder.rec(key_hash, self.address_space) + key_builder.rec(key_hash, self.is_output_only) + # Making this a function prevents incorrect use in isinstance. # Note: This is *not* deprecated, as it is super-common and diff --git a/test/test_loopy.py b/test/test_loopy.py index 231b70bf71d865e5b9832332c90f3228a0a26b82..119d57adf2c850eba3bb6ad5df3c0a8d0644b70c 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -3003,6 +3003,18 @@ def test_shape_mismatch_check(ctx_factory): prg(queue, a=a, b=b) +def test_array_arg_extra_kwargs_persis_hash(): + from loopy.tools import LoopyKeyBuilder + + a = lp.ArrayArg('a', shape=(10, ), dtype=np.float64, + address_space=lp.AddressSpace.LOCAL) + not_a = lp.ArrayArg('a', shape=(10, ), dtype=np.float64, + address_space=lp.AddressSpace.PRIVATE) + + key_builder = LoopyKeyBuilder() + assert key_builder(a) != key_builder(not_a) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])