From 47dfe4f7ae7a751ff7cbdcda0c56446939d82a9c Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Wed, 11 Oct 2023 09:48:37 -0500 Subject: [PATCH] persistent_dict: add complex hashing, numpy scalar hashing (#184) * persistent_dict: add complex hashing, numpy scalar hashing * simplify * fix * better update_for_complex * convert numpy scalars to 1D array * use np.array instead * tobytes() is not stable, try str+dtype * fix number detection * Revert "tobytes() is not stable, try str+dtype" This reverts commit f1b766c791299b08bf4f7090b9b47b4dbfe2678a. * Revert "fix number detection" This reverts commit 078c75a4027146c8c4b27534e33b461e6cf908f1. * convert large float types to python types * Revert "Revert "fix number detection"" This reverts commit fc4350a8a526aebeea73f7ce15aa847bc76208da. --- pytools/persistent_dict.py | 59 +++++++++++++++++++++++++----------- test/test_persistent_dict.py | 41 +++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 703a633..6a9f0dd 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -238,26 +238,35 @@ class KeyBuilder: try: method = getattr(self, "update_for_"+tname) except AttributeError: - if ( - # Handling numpy >= 1.20, for which - # type(np.dtype("float32")) -> "dtype[float32]" - tname.startswith("dtype[") - # Handling numpy >= 1.25, for which - # type(np.dtype("float32")) -> "Float32DType" - or tname.endswith("DType") - ) and "numpy" in sys.modules: + if "numpy" in sys.modules: import numpy as np - if isinstance(key, np.dtype): - method = self.update_for_specific_dtype - elif issubclass(tp, Enum): - method = self.update_for_enum - - elif is_dataclass(tp): - method = self.update_for_dataclass - - elif _HAS_ATTRS and attrs.has(tp): - method = self.update_for_attrs + # Hashing numpy dtypes + if ( + # Handling numpy >= 1.20, for which + # type(np.dtype("float32")) -> "dtype[float32]" + tname.startswith("dtype[") + # Handling numpy >= 1.25, for which + # type(np.dtype("float32")) -> "Float32DType" + or tname.endswith("DType") + ): + if isinstance(key, np.dtype): + method = self.update_for_specific_dtype + + # Hashing numpy scalars + elif isinstance(key, np.number): + # Non-numpy scalars are handled above in the try block. + method = self.update_for_numpy_scalar + + if method is None: + if issubclass(tp, Enum): + method = self.update_for_enum + + elif is_dataclass(tp): + method = self.update_for_dataclass + + elif _HAS_ATTRS and attrs.has(tp): + method = self.update_for_attrs if method is not None: inner_key_hash = self.new_hash() @@ -314,6 +323,10 @@ class KeyBuilder: def update_for_float(key_hash, key): key_hash.update(key.hex().encode("utf8")) + @staticmethod + def update_for_complex(key_hash, key): + key_hash.update(repr(key).encode("utf-8")) + @staticmethod def update_for_str(key_hash, key): key_hash.update(key.encode("utf8")) @@ -350,6 +363,16 @@ class KeyBuilder: def update_for_specific_dtype(key_hash, key): key_hash.update(key.str.encode("utf8")) + @staticmethod + def update_for_numpy_scalar(key_hash, key): + import numpy as np + if hasattr(np, "complex256") and key.dtype == np.dtype("complex256"): + key_hash.update(repr(complex(key)).encode("utf8")) + elif hasattr(np, "float128") and key.dtype == np.dtype("float128"): + key_hash.update(repr(float(key)).encode("utf8")) + else: + key_hash.update(np.array(key).tobytes()) + def update_for_dataclass(self, key_hash, key): self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) diff --git a/test/test_persistent_dict.py b/test/test_persistent_dict.py index 59dd491..754b652 100644 --- a/test/test_persistent_dict.py +++ b/test/test_persistent_dict.py @@ -380,6 +380,47 @@ def test_dtype_hashing(): assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32)) +def test_scalar_hashing(): + keyb = KeyBuilder() + + assert keyb(1) == keyb(1) + assert keyb(2) != keyb(1) + assert keyb(1.1) == keyb(1.1) + assert keyb(1+4j) == keyb(1+4j) + + try: + import numpy as np + except ImportError: + return + + assert keyb(np.int8(1)) == keyb(np.int8(1)) + assert keyb(np.int16(1)) == keyb(np.int16(1)) + assert keyb(np.int32(1)) == keyb(np.int32(1)) + assert keyb(np.int32(2)) != keyb(np.int32(1)) + assert keyb(np.int64(1)) == keyb(np.int64(1)) + assert keyb(1) == keyb(np.int64(1)) + assert keyb(1) != keyb(np.int32(1)) + + assert keyb(np.longlong(1)) == keyb(np.longlong(1)) + + assert keyb(np.float16(1.1)) == keyb(np.float16(1.1)) + assert keyb(np.float32(1.1)) == keyb(np.float32(1.1)) + assert keyb(np.float64(1.1)) == keyb(np.float64(1.1)) + if hasattr(np, "float128"): + assert keyb(np.float128(1.1)) == keyb(np.float128(1.1)) + + assert keyb(np.longfloat(1.1)) == keyb(np.longfloat(1.1)) + assert keyb(np.longdouble(1.1)) == keyb(np.longdouble(1.1)) + + assert keyb(np.complex64(1.1+2.2j)) == keyb(np.complex64(1.1+2.2j)) + assert keyb(np.complex128(1.1+2.2j)) == keyb(np.complex128(1.1+2.2j)) + if hasattr(np, "complex256"): + assert keyb(np.complex256(1.1+2.2j)) == keyb(np.complex256(1.1+2.2j)) + + assert keyb(np.longcomplex(1.1+2.2j)) == keyb(np.longcomplex(1.1+2.2j)) + assert keyb(np.clongdouble(1.1+2.2j)) == keyb(np.clongdouble(1.1+2.2j)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab