From 7e07c6922911499b64d687c769919a2c3308af60 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 21 Feb 2024 17:08:56 -0600 Subject: [PATCH] KeyBuilder: fix hashing of dataclasses/attrs, add test (#196) * KeyBuilder: fix hashing of dataclasses/attrs, add test * warn on missing attrs * install attrs * break out attrs test * add another small test --- .github/workflows/ci.yml | 2 +- pytools/persistent_dict.py | 4 +- pytools/test/test_persistent_dict.py | 55 ++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c29a01f..a88e66a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,7 +83,7 @@ jobs: # AK, 2020-12-13 rm pytools/mpiwrap.py - EXTRA_INSTALL="numpy frozendict immutabledict orderedsets constantdict immutables pyrsistent" + EXTRA_INSTALL="numpy frozendict immutabledict orderedsets constantdict immutables pyrsistent attrs" curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh . ./build-and-test-py-project.sh diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 421e02a..47d2a8c 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -381,14 +381,14 @@ class KeyBuilder: key_hash.update(np.array(key).tobytes()) def update_for_dataclass(self, key_hash, key): - self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) + self.rec(key_hash, f"{type(key).__qualname__}.{type(key).__name__}") for fld in dc_fields(key): self.rec(key_hash, fld.name) self.rec(key_hash, getattr(key, fld.name, None)) def update_for_attrs(self, key_hash, key): - self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) + self.rec(key_hash, f"{type(key).__qualname__}.{type(key).__name__}") for fld in attrs.fields(key.__class__): self.rec(key_hash, fld.name) diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index eb22ec6..3c852b1 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -518,6 +518,61 @@ def test_class_hashing(): assert keyb(TagClass2) != keyb(TagClass2()) +def test_dataclass_hashing(): + keyb = KeyBuilder() + + @dataclass + class MyDC: + name: str + value: int + + assert keyb(MyDC("hi", 1)) == \ + "2ba6363c3b98f1cc2209bd57388368b3efe3074e3764eee30fbcf15946efb802" + + assert keyb(MyDC("hi", 1)) == keyb(MyDC("hi", 1)) + assert keyb(MyDC("hi", 1)) != keyb(MyDC("hi", 2)) + + @dataclass + class MyDC2: + name: str + value: int + + # Class types must be encoded in hash + assert keyb(MyDC2("hi", 1)) != keyb(MyDC("hi", 1)) + + +def test_attrs_hashing(): + attrs = pytest.importorskip("attrs") + + keyb = KeyBuilder() + + @attrs.define + class MyAttrs: + name: str + value: int + + assert keyb(MyAttrs("hi", 1)) == \ + "17f272d114d22c1dc0117354777f2d506b303d90e10840d39fb0eef007252f68" + + assert keyb(MyAttrs("hi", 1)) == keyb(MyAttrs("hi", 1)) + assert keyb(MyAttrs("hi", 1)) != keyb(MyAttrs("hi", 2)) + + @dataclass + class MyDC: + name: str + value: int + + assert keyb(MyDC("hi", 1)) != keyb(MyAttrs("hi", 1)) + + @attrs.define + class MyAttrs2: + name: str + value: int + + # Class types must be encoded in hash + assert keyb(MyAttrs2("hi", 1)) != keyb(MyAttrs("hi", 1)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab