diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c29a01fae3d831562f39d308584ed45f2fb05ec8..a88e66a71e977f6d54ed0108f06edc98629a4cfe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,7 +83,7 @@ jobs: # AK, 2020-12-13 rm pytools/mpiwrap.py - EXTRA_INSTALL="numpy frozendict immutabledict orderedsets constantdict immutables pyrsistent" + EXTRA_INSTALL="numpy frozendict immutabledict orderedsets constantdict immutables pyrsistent attrs" curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project.sh . ./build-and-test-py-project.sh diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 421e02a1d72fd099fcd20723e7461cba95a9074e..47d2a8c0130f4eaf8a8308c2831b6065c89ffdcd 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -381,14 +381,14 @@ class KeyBuilder: key_hash.update(np.array(key).tobytes()) def update_for_dataclass(self, key_hash, key): - self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) + self.rec(key_hash, f"{type(key).__qualname__}.{type(key).__name__}") for fld in dc_fields(key): self.rec(key_hash, fld.name) self.rec(key_hash, getattr(key, fld.name, None)) def update_for_attrs(self, key_hash, key): - self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) + self.rec(key_hash, f"{type(key).__qualname__}.{type(key).__name__}") for fld in attrs.fields(key.__class__): self.rec(key_hash, fld.name) diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index eb22ec65b9e44b26a9fe3a5aa1abe733e413eb40..3c852b1bbbc4ca65edc49be9c50a4a0642995dcd 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -518,6 +518,61 @@ def test_class_hashing(): assert keyb(TagClass2) != keyb(TagClass2()) +def test_dataclass_hashing(): + keyb = KeyBuilder() + + @dataclass + class MyDC: + name: str + value: int + + assert keyb(MyDC("hi", 1)) == \ + "2ba6363c3b98f1cc2209bd57388368b3efe3074e3764eee30fbcf15946efb802" + + assert keyb(MyDC("hi", 1)) == keyb(MyDC("hi", 1)) + assert keyb(MyDC("hi", 1)) != keyb(MyDC("hi", 2)) + + @dataclass + class MyDC2: + name: str + value: int + + # Class types must be encoded in hash + assert keyb(MyDC2("hi", 1)) != keyb(MyDC("hi", 1)) + + +def test_attrs_hashing(): + attrs = pytest.importorskip("attrs") + + keyb = KeyBuilder() + + @attrs.define + class MyAttrs: + name: str + value: int + + assert keyb(MyAttrs("hi", 1)) == \ + "17f272d114d22c1dc0117354777f2d506b303d90e10840d39fb0eef007252f68" + + assert keyb(MyAttrs("hi", 1)) == keyb(MyAttrs("hi", 1)) + assert keyb(MyAttrs("hi", 1)) != keyb(MyAttrs("hi", 2)) + + @dataclass + class MyDC: + name: str + value: int + + assert keyb(MyDC("hi", 1)) != keyb(MyAttrs("hi", 1)) + + @attrs.define + class MyAttrs2: + name: str + value: int + + # Class types must be encoded in hash + assert keyb(MyAttrs2("hi", 1)) != keyb(MyAttrs("hi", 1)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])