diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 703a6336d8ba8ae8fbebe894ef4b198f6bb87c01..6a9f0ddafcb58ce27060ec4592259eadd2b726d6 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -238,26 +238,35 @@ class KeyBuilder: try: method = getattr(self, "update_for_"+tname) except AttributeError: - if ( - # Handling numpy >= 1.20, for which - # type(np.dtype("float32")) -> "dtype[float32]" - tname.startswith("dtype[") - # Handling numpy >= 1.25, for which - # type(np.dtype("float32")) -> "Float32DType" - or tname.endswith("DType") - ) and "numpy" in sys.modules: + if "numpy" in sys.modules: import numpy as np - if isinstance(key, np.dtype): - method = self.update_for_specific_dtype - elif issubclass(tp, Enum): - method = self.update_for_enum - - elif is_dataclass(tp): - method = self.update_for_dataclass - - elif _HAS_ATTRS and attrs.has(tp): - method = self.update_for_attrs + # Hashing numpy dtypes + if ( + # Handling numpy >= 1.20, for which + # type(np.dtype("float32")) -> "dtype[float32]" + tname.startswith("dtype[") + # Handling numpy >= 1.25, for which + # type(np.dtype("float32")) -> "Float32DType" + or tname.endswith("DType") + ): + if isinstance(key, np.dtype): + method = self.update_for_specific_dtype + + # Hashing numpy scalars + elif isinstance(key, np.number): + # Non-numpy scalars are handled above in the try block. + method = self.update_for_numpy_scalar + + if method is None: + if issubclass(tp, Enum): + method = self.update_for_enum + + elif is_dataclass(tp): + method = self.update_for_dataclass + + elif _HAS_ATTRS and attrs.has(tp): + method = self.update_for_attrs if method is not None: inner_key_hash = self.new_hash() @@ -314,6 +323,10 @@ class KeyBuilder: def update_for_float(key_hash, key): key_hash.update(key.hex().encode("utf8")) + @staticmethod + def update_for_complex(key_hash, key): + key_hash.update(repr(key).encode("utf-8")) + @staticmethod def update_for_str(key_hash, key): key_hash.update(key.encode("utf8")) @@ -350,6 +363,16 @@ class KeyBuilder: def update_for_specific_dtype(key_hash, key): key_hash.update(key.str.encode("utf8")) + @staticmethod + def update_for_numpy_scalar(key_hash, key): + import numpy as np + if hasattr(np, "complex256") and key.dtype == np.dtype("complex256"): + key_hash.update(repr(complex(key)).encode("utf8")) + elif hasattr(np, "float128") and key.dtype == np.dtype("float128"): + key_hash.update(repr(float(key)).encode("utf8")) + else: + key_hash.update(np.array(key).tobytes()) + def update_for_dataclass(self, key_hash, key): self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) diff --git a/test/test_persistent_dict.py b/test/test_persistent_dict.py index 59dd49139aaf827995ce0fde9e263e34fccaa79c..754b65247eb538603d144f7602a90192969543ce 100644 --- a/test/test_persistent_dict.py +++ b/test/test_persistent_dict.py @@ -380,6 +380,47 @@ def test_dtype_hashing(): assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32)) +def test_scalar_hashing(): + keyb = KeyBuilder() + + assert keyb(1) == keyb(1) + assert keyb(2) != keyb(1) + assert keyb(1.1) == keyb(1.1) + assert keyb(1+4j) == keyb(1+4j) + + try: + import numpy as np + except ImportError: + return + + assert keyb(np.int8(1)) == keyb(np.int8(1)) + assert keyb(np.int16(1)) == keyb(np.int16(1)) + assert keyb(np.int32(1)) == keyb(np.int32(1)) + assert keyb(np.int32(2)) != keyb(np.int32(1)) + assert keyb(np.int64(1)) == keyb(np.int64(1)) + assert keyb(1) == keyb(np.int64(1)) + assert keyb(1) != keyb(np.int32(1)) + + assert keyb(np.longlong(1)) == keyb(np.longlong(1)) + + assert keyb(np.float16(1.1)) == keyb(np.float16(1.1)) + assert keyb(np.float32(1.1)) == keyb(np.float32(1.1)) + assert keyb(np.float64(1.1)) == keyb(np.float64(1.1)) + if hasattr(np, "float128"): + assert keyb(np.float128(1.1)) == keyb(np.float128(1.1)) + + assert keyb(np.longfloat(1.1)) == keyb(np.longfloat(1.1)) + assert keyb(np.longdouble(1.1)) == keyb(np.longdouble(1.1)) + + assert keyb(np.complex64(1.1+2.2j)) == keyb(np.complex64(1.1+2.2j)) + assert keyb(np.complex128(1.1+2.2j)) == keyb(np.complex128(1.1+2.2j)) + if hasattr(np, "complex256"): + assert keyb(np.complex256(1.1+2.2j)) == keyb(np.complex256(1.1+2.2j)) + + assert keyb(np.longcomplex(1.1+2.2j)) == keyb(np.longcomplex(1.1+2.2j)) + assert keyb(np.clongdouble(1.1+2.2j)) == keyb(np.clongdouble(1.1+2.2j)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])