diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 62a7f5d16761b0482069c3429fb7f3e7fb302611..5d830bd8a6825a0754a950c7f4f501d0ed1cc5cd 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -221,7 +221,7 @@ class KeyBuilder: method = self.update_for_specific_dtype # Hashing numpy scalars - elif isinstance(key, np.number): + elif isinstance(key, np.number | np.bool_): # Non-numpy scalars are handled above in the try block. method = self.update_for_numpy_scalar diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index 727e589edf1bb54f189530daf10e7c3eb21dfeb6..ed41c5defeaa6990554b8658b331bb0e731aa546 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -414,6 +414,31 @@ def test_dtype_hashing() -> None: assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32)) +def test_bool_hashing() -> None: + keyb = KeyBuilder() + + assert keyb(True) == keyb(True) + assert keyb(False) == keyb(False) + assert keyb(True) != keyb(False) + + np = pytest.importorskip("numpy") + + bool_types = [np.bool_] + if hasattr(np, "bool"): + bool_types.append(np.bool) + + for bool_type in bool_types: + assert keyb(bool_type) != keyb(bool) + + assert keyb(bool_type(True)) == keyb(bool_type(True)) + assert keyb(bool_type(False)) == keyb(bool_type(False)) + assert keyb(bool_type(True)) != keyb(bool_type(False)) + + assert keyb(bool_type) != keyb(np.dtype(bool_type)) + assert keyb(bool_type(True)) != keyb(np.dtype(bool_type(True))) + assert keyb(bool_type(False)) != keyb(np.dtype(bool_type(False))) + + def test_scalar_hashing() -> None: keyb = KeyBuilder()