diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c9f04bdf48bd6280f3c82f71614826d2474c7e9..4149996467f0e65c3f277c858fef2670bbc7fd1c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,7 +83,7 @@ jobs: # AK, 2020-12-13 rm pytools/mpiwrap.py - EXTRA_INSTALL="numpy" + EXTRA_INSTALL="numpy frozendict immutabledict orderedsets" curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh . ./build-and-test-py-project.sh diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 6a9f0ddafcb58ce27060ec4592259eadd2b726d6..65fb6e9c56e3ef3d24e26924a6d5f855900250e3 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -346,6 +346,8 @@ class KeyBuilder: key_hash, (self.rec(self.new_hash(), key_i).digest() for key_i in key)) + update_for_FrozenOrderedSet = update_for_frozenset # noqa: N815 + @staticmethod def update_for_NoneType(key_hash, key): # noqa del key @@ -387,6 +389,15 @@ class KeyBuilder: self.rec(key_hash, fld.name) self.rec(key_hash, getattr(key, fld.name, None)) + def update_for_frozendict(self, key_hash, key): + from pytools import unordered_hash + + unordered_hash( + key_hash, + (self.rec(self.new_hash(), (k, v)).digest() for k, v in key.items())) + + update_for_immutabledict = update_for_frozendict + # }}} # }}} diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index 754b65247eb538603d144f7602a90192969543ce..2c3aeabd54d9cb74b83c587dde116c9fb091fa88 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -421,6 +421,55 @@ def test_scalar_hashing(): assert keyb(np.clongdouble(1.1+2.2j)) == keyb(np.clongdouble(1.1+2.2j)) +def test_frozendict_hashing(): + pytest.importorskip("frozendict") + from frozendict import frozendict + + keyb = KeyBuilder() + + d = {"a": 1, "b": 2} + + assert keyb(frozendict(d)) == keyb(frozendict(d)) + assert keyb(frozendict(d)) != keyb(frozendict({"a": 1, "b": 3})) + assert keyb(frozendict(d)) == keyb(frozendict({"b": 2, "a": 1})) + + with pytest.raises(TypeError): + keyb(d) + + +def test_immutabledict_hashing(): + pytest.importorskip("immutabledict") + from immutabledict import immutabledict + + keyb = KeyBuilder() + + d = {"a": 1, "b": 2} + + assert keyb(immutabledict(d)) == keyb(immutabledict(d)) + assert keyb(immutabledict(d)) != keyb(immutabledict({"a": 1, "b": 3})) + assert keyb(immutabledict(d)) == keyb(immutabledict({"b": 2, "a": 1})) + + +def test_frozenset_hashing(): + keyb = KeyBuilder() + + assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([1, 2, 3])) + assert keyb(frozenset([1, 2, 3])) != keyb(frozenset([1, 2, 4])) + assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([3, 2, 1])) + + +def test_frozenorderedset_hashing(): + pytest.importorskip("orderedsets") + from orderedsets import FrozenOrderedSet + keyb = KeyBuilder() + + assert (keyb(FrozenOrderedSet([1, 2, 3])) + == keyb(FrozenOrderedSet([1, 2, 3])) + == keyb(frozenset([1, 2, 3]))) + assert keyb(FrozenOrderedSet([1, 2, 3])) != keyb(FrozenOrderedSet([1, 2, 4])) + assert keyb(FrozenOrderedSet([1, 2, 3])) == keyb(FrozenOrderedSet([3, 2, 1])) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])