From 222320481a40ef4e0152f028975bfd4964b9393c Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Mon, 6 Nov 2023 14:27:04 -0600
Subject: [PATCH] PersistentDict: support frozendict, immutabledict,
 FrozenOrderedSet

---
 .github/workflows/ci.yml             |  2 +-
 pytools/persistent_dict.py           | 11 +++++++
 pytools/test/test_persistent_dict.py | 49 ++++++++++++++++++++++++++++
 3 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6c9f04b..4149996 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -83,7 +83,7 @@ jobs:
                 # AK, 2020-12-13
                 rm pytools/mpiwrap.py
 
-                EXTRA_INSTALL="numpy"
+                EXTRA_INSTALL="numpy frozendict immutabledict orderedsets"
                 curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh
                 . ./build-and-test-py-project.sh
 
diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 6a9f0dd..65fb6e9 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -346,6 +346,8 @@ class KeyBuilder:
             key_hash,
             (self.rec(self.new_hash(), key_i).digest() for key_i in key))
 
+    update_for_FrozenOrderedSet = update_for_frozenset  # noqa: N815
+
     @staticmethod
     def update_for_NoneType(key_hash, key):  # noqa
         del key
@@ -387,6 +389,15 @@ class KeyBuilder:
             self.rec(key_hash, fld.name)
             self.rec(key_hash, getattr(key, fld.name, None))
 
+    def update_for_frozendict(self, key_hash, key):
+        from pytools import unordered_hash
+
+        unordered_hash(
+            key_hash,
+            (self.rec(self.new_hash(), (k, v)).digest() for k, v in key.items()))
+
+    update_for_immutabledict = update_for_frozendict
+
     # }}}
 
 # }}}
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index 754b652..2c3aeab 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -421,6 +421,55 @@ def test_scalar_hashing():
     assert keyb(np.clongdouble(1.1+2.2j)) == keyb(np.clongdouble(1.1+2.2j))
 
 
+def test_frozendict_hashing():
+    pytest.importorskip("frozendict")
+    from frozendict import frozendict
+
+    keyb = KeyBuilder()
+
+    d = {"a": 1, "b": 2}
+
+    assert keyb(frozendict(d)) == keyb(frozendict(d))
+    assert keyb(frozendict(d)) != keyb(frozendict({"a": 1, "b": 3}))
+    assert keyb(frozendict(d)) == keyb(frozendict({"b": 2, "a": 1}))
+
+    with pytest.raises(TypeError):
+        keyb(d)
+
+
+def test_immutabledict_hashing():
+    pytest.importorskip("immutabledict")
+    from immutabledict import immutabledict
+
+    keyb = KeyBuilder()
+
+    d = {"a": 1, "b": 2}
+
+    assert keyb(immutabledict(d)) == keyb(immutabledict(d))
+    assert keyb(immutabledict(d)) != keyb(immutabledict({"a": 1, "b": 3}))
+    assert keyb(immutabledict(d)) == keyb(immutabledict({"b": 2, "a": 1}))
+
+
+def test_frozenset_hashing():
+    keyb = KeyBuilder()
+
+    assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([1, 2, 3]))
+    assert keyb(frozenset([1, 2, 3])) != keyb(frozenset([1, 2, 4]))
+    assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([3, 2, 1]))
+
+
+def test_frozenorderedset_hashing():
+    pytest.importorskip("orderedsets")
+    from orderedsets import FrozenOrderedSet
+    keyb = KeyBuilder()
+
+    assert (keyb(FrozenOrderedSet([1, 2, 3]))
+            == keyb(FrozenOrderedSet([1, 2, 3]))
+            == keyb(frozenset([1, 2, 3])))
+    assert keyb(FrozenOrderedSet([1, 2, 3])) != keyb(FrozenOrderedSet([1, 2, 4]))
+    assert keyb(FrozenOrderedSet([1, 2, 3])) == keyb(FrozenOrderedSet([3, 2, 1]))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab