From d9d009b981a75f888b4bcac32a4e0c8637ef1294 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 10 Mar 2021 16:17:28 -0600
Subject: [PATCH] unordered_hash: update a hash instance, instead of forcing
 creation of a new one

---
 pytools/__init__.py        | 17 +++++++++++++----
 pytools/persistent_dict.py |  8 +++-----
 test/test_pytools.py       | 16 ++++++++--------
 3 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/pytools/__init__.py b/pytools/__init__.py
index c8cfa90..f1f1ebf 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -2623,7 +2623,7 @@ def resolve_name(name):
 
 # {{{ unordered_hash
 
-def unordered_hash(hash_constructor, iterable):
+def unordered_hash(hash_instance, iterable, hash_constructor=None):
     """Using a hash algorithm given by the parameter-less constructor
     *hash_constructor*, return a hash object whose internal state
     depends on the entries of *iterable*, but not their order. If *hash*
@@ -2631,6 +2631,10 @@ def unordered_hash(hash_constructor, iterable):
     the each entry *i* of the iterable must permit ``hash.upate(i)`` to
     succeed. An example of *hash_constructor* is ``hashlib.sha256``
     from :mod:`hashlib`.  ``hash.digest_size`` must also be defined.
+    If *hash_constructor* is not provided, ``hash_instance.name`` is
+    used to deduce it.
+
+    :returns: the updated *hash_instance*.
 
     .. warning::
 
@@ -2639,6 +2643,12 @@ def unordered_hash(hash_constructor, iterable):
 
     .. versionadded:: 2021.2
     """
+
+    if hash_constructor is None:
+        from functools import partial
+        import hashlib
+        hash_constructor = partial(hashlib.new, hash_instance.name)
+
     h_int = 0
     for i in iterable:
         h_i = hash_constructor()
@@ -2650,9 +2660,8 @@ def unordered_hash(hash_constructor, iterable):
         # mix adjacent bits).
         h_int = h_int ^ int.from_bytes(h_i.digest(), sys.byteorder)
 
-    h = hash_constructor()
-    h.update(h_int.to_bytes(h.digest_size, sys.byteorder))
-    return h
+    hash_instance.update(h_int.to_bytes(hash_instance.digest_size, sys.byteorder))
+    return hash_instance
 
 # }}}
 
diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index f48027f..c3a7b1d 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -294,11 +294,9 @@ class KeyBuilder:
     def update_for_frozenset(self, key_hash, key):
         from pytools import unordered_hash
 
-        self.rec(key_hash,
-                unordered_hash(
-                    self.new_hash,
-                    (self.rec(self.new_hash(), key_i).digest() for key_i in key)
-                    ).digest())
+        unordered_hash(
+            key_hash,
+            (self.rec(self.new_hash(), key_i).digest() for key_i in key))
 
     @staticmethod
     def update_for_NoneType(key_hash, key):  # noqa
diff --git a/test/test_pytools.py b/test/test_pytools.py
index 49a4390..356e47c 100644
--- a/test/test_pytools.py
+++ b/test/test_pytools.py
@@ -446,15 +446,15 @@ def test_unordered_hash():
     random.shuffle(lst)
 
     from pytools import unordered_hash
-    assert (unordered_hash(hashlib.sha256, lorig).digest()
-            == unordered_hash(hashlib.sha256, lst).digest())
-    assert (unordered_hash(hashlib.sha256, lorig).digest()
-            == unordered_hash(hashlib.sha256, lorig).digest())
-    assert (unordered_hash(hashlib.sha256, lorig).digest()
-            != unordered_hash(hashlib.sha256, lorig[:-1]).digest())
+    assert (unordered_hash(hashlib.sha256(), lorig).digest()
+            == unordered_hash(hashlib.sha256(), lst).digest())
+    assert (unordered_hash(hashlib.sha256(), lorig).digest()
+            == unordered_hash(hashlib.sha256(), lorig).digest())
+    assert (unordered_hash(hashlib.sha256(), lorig).digest()
+            != unordered_hash(hashlib.sha256(), lorig[:-1]).digest())
     lst[0] = b"aksdjfla;sdfjafd"
-    assert (unordered_hash(hashlib.sha256, lorig).digest()
-            != unordered_hash(hashlib.sha256, lst).digest())
+    assert (unordered_hash(hashlib.sha256(), lorig).digest()
+            != unordered_hash(hashlib.sha256(), lst).digest())
 
 
 if __name__ == "__main__":
-- 
GitLab