From 2e9cf0d2b6f67bf99897a89b646b652fbd4df05e Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Tue, 12 Nov 2024 14:51:55 -0600
Subject: [PATCH] KeyBuilder: support np.bool{,_} for numpy<2

---
 pytools/persistent_dict.py           |  2 +-
 pytools/test/test_persistent_dict.py | 25 +++++++++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 62a7f5d..5d830bd 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -221,7 +221,7 @@ class KeyBuilder:
                             method = self.update_for_specific_dtype
 
                     # Hashing numpy scalars
-                    elif isinstance(key, np.number):
+                    elif isinstance(key, np.number | np.bool_):
                         # Non-numpy scalars are handled above in the try block.
                         method = self.update_for_numpy_scalar
 
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index 727e589..ed41c5d 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -414,6 +414,31 @@ def test_dtype_hashing() -> None:
     assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32))
 
 
+def test_bool_hashing() -> None:
+    keyb = KeyBuilder()
+
+    assert keyb(True) == keyb(True)
+    assert keyb(False) == keyb(False)
+    assert keyb(True) != keyb(False)
+
+    np = pytest.importorskip("numpy")
+
+    bool_types = [np.bool_]
+    if hasattr(np, "bool"):
+        bool_types.append(np.bool)
+
+    for bool_type in bool_types:
+        assert keyb(bool_type) != keyb(bool)
+
+        assert keyb(bool_type(True)) == keyb(bool_type(True))
+        assert keyb(bool_type(False)) == keyb(bool_type(False))
+        assert keyb(bool_type(True)) != keyb(bool_type(False))
+
+        assert keyb(bool_type) != keyb(np.dtype(bool_type))
+        assert keyb(bool_type(True)) != keyb(np.dtype(bool_type(True)))
+        assert keyb(bool_type(False)) != keyb(np.dtype(bool_type(False)))
+
+
 def test_scalar_hashing() -> None:
     keyb = KeyBuilder()
 
-- 
GitLab