From bda06579bc86a268abb4ca94b7a16c3cfa07f27f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 18 Jun 2023 09:04:24 -0500 Subject: [PATCH] Support attrs classes for persistent hashing --- pytools/persistent_dict.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 895aff6..2723407 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -27,16 +27,22 @@ THE SOFTWARE. """ import collections.abc as abc +import errno import hashlib import logging -from dataclasses import fields, is_dataclass -from enum import Enum - - -import errno import os import shutil import sys +from dataclasses import fields as dc_fields, is_dataclass +from enum import Enum + + +try: + import attrs +except ModuleNotFoundError: + _HAS_ATTRS = False +else: + _HAS_ATTRS = True logger = logging.getLogger(__name__) @@ -250,6 +256,9 @@ class KeyBuilder: elif is_dataclass(tp): method = self.update_for_dataclass + elif _HAS_ATTRS and attrs.has(tp): + method = self.update_for_attrs + if method is not None: inner_key_hash = self.new_hash() method(inner_key_hash, key) @@ -344,7 +353,14 @@ class KeyBuilder: def update_for_dataclass(self, key_hash, key): self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) - for fld in fields(key): + for fld in dc_fields(key): + self.rec(key_hash, fld.name) + self.rec(key_hash, getattr(key, fld.name, None)) + + def update_for_attrs(self, key_hash, key): + self.rec(key_hash, type(key_hash).__name__.encode("utf-8")) + + for fld in attrs.fields(key.__class__): self.rec(key_hash, fld.name) self.rec(key_hash, getattr(key, fld.name, None)) -- GitLab