diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 895aff64ba7f0829817251f4119f3963d91345ec..272340724bafb92fa3fe6c0ba8c8c3c09a235dab 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -27,16 +27,22 @@ THE SOFTWARE. """ import collections.abc as abc +import errno import hashlib import logging -from dataclasses import fields, is_dataclass -from enum import Enum - - -import errno import os import shutil import sys +from dataclasses import fields as dc_fields, is_dataclass +from enum import Enum + + +try: + import attrs +except ModuleNotFoundError: + _HAS_ATTRS = False +else: + _HAS_ATTRS = True logger = logging.getLogger(__name__) @@ -250,6 +256,9 @@ class KeyBuilder: elif is_dataclass(tp): method = self.update_for_dataclass + elif _HAS_ATTRS and attrs.has(tp): + method = self.update_for_attrs + if method is not None: inner_key_hash = self.new_hash() method(inner_key_hash, key) @@ -344,7 +353,14 @@ class KeyBuilder: def update_for_dataclass(self, key_hash, key): self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) - for fld in fields(key): + for fld in dc_fields(key): + self.rec(key_hash, fld.name) + self.rec(key_hash, getattr(key, fld.name, None)) + + def update_for_attrs(self, key_hash, key): + self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) + + for fld in attrs.fields(key.__class__): self.rec(key_hash, fld.name) self.rec(key_hash, getattr(key, fld.name, None))