From 2e9cf0d2b6f67bf99897a89b646b652fbd4df05e Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Tue, 12 Nov 2024 14:51:55 -0600 Subject: [PATCH] KeyBuilder: support np.bool{,_} for numpy<2 --- pytools/persistent_dict.py | 2 +- pytools/test/test_persistent_dict.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 62a7f5d..5d830bd 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -221,7 +221,7 @@ class KeyBuilder: method = self.update_for_specific_dtype # Hashing numpy scalars - elif isinstance(key, np.number): + elif isinstance(key, np.number | np.bool_): # Non-numpy scalars are handled above in the try block. method = self.update_for_numpy_scalar diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index 727e589..ed41c5d 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -414,6 +414,31 @@ def test_dtype_hashing() -> None: assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32)) +def test_bool_hashing() -> None: + keyb = KeyBuilder() + + assert keyb(True) == keyb(True) + assert keyb(False) == keyb(False) + assert keyb(True) != keyb(False) + + np = pytest.importorskip("numpy") + + bool_types = [np.bool_] + if hasattr(np, "bool"): + bool_types.append(np.bool) + + for bool_type in bool_types: + assert keyb(bool_type) != keyb(bool) + + assert keyb(bool_type(True)) == keyb(bool_type(True)) + assert keyb(bool_type(False)) == keyb(bool_type(False)) + assert keyb(bool_type(True)) != keyb(bool_type(False)) + + assert keyb(bool_type) != keyb(np.dtype(bool_type)) + assert keyb(bool_type(True)) != keyb(np.dtype(bool_type(True))) + assert keyb(bool_type(False)) != keyb(np.dtype(bool_type(False))) + + def test_scalar_hashing() -> None: keyb = KeyBuilder() -- GitLab