From 0cbb5c1c5425f498c3569e826bc5f0becdc06016 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 7 Mar 2021 18:30:57 -0600
Subject: [PATCH] Support pytools.tag in persistent_dict

---
 pytools/tag.py               | 11 +++++++++++
 test/test_persistent_dict.py | 10 +++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/pytools/tag.py b/pytools/tag.py
index b943d95..ca541ad 100644
--- a/pytools/tag.py
+++ b/pytools/tag.py
@@ -122,6 +122,17 @@ class Tag:
     def tag_name(self) -> DottedName:
         return DottedName.from_class(type(self))
 
+    def update_persistent_hash(self, key_hash, key_builder):
+        key_builder.rec(key_hash, self.__class__.__qualname__)
+
+        from dataclasses import fields
+        # Fields are ordered consistently, so ordered hashing is OK.
+        #
+        # No need to dispatch to superclass: fields() automatically gives us
+        # fields from the entire class hierarchy.
+        for f in fields(self):
+            key_builder.rec(key_hash, getattr(self, f.name))
+
 # }}}
 
 
diff --git a/test/test_persistent_dict.py b/test/test_persistent_dict.py
index 9fe775b..bf1fd54 100644
--- a/test/test_persistent_dict.py
+++ b/test/test_persistent_dict.py
@@ -4,6 +4,7 @@ import tempfile
 
 import pytest
 
+from pytools.tag import Tag, tag_dataclass
 from pytools.persistent_dict import (CollisionWarning, NoSuchEntryError,
         PersistentDict, ReadOnlyEntryError, WriteOncePersistentDict)
 
@@ -39,6 +40,11 @@ class PDictTestingKeyOrValue:
 # }}}
 
 
+@tag_dataclass
+class SomeTag(Tag):
+    value: str
+
+
 def test_persistent_dict_storage_and_lookup():
     try:
         tmpdir = tempfile.mkdtemp()
@@ -51,7 +57,9 @@ def test_persistent_dict_storage_and_lookup():
                     chr(65+randrange(26))
                     for i in range(n))
 
-        keys = [(randrange(2000), rand_str(), None) for i in range(20)]
+        keys = [
+                (randrange(2000), rand_str(), None, SomeTag(rand_str()))
+                for i in range(20)]
         values = [randrange(2000) for i in range(20)]
 
         d = dict(list(zip(keys, values)))
-- 
GitLab