From f0846691d9b2c3575583b93570cff5fe237a794c Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Tue, 21 May 2024 14:22:51 -0500
Subject: [PATCH] support closures

---
 pytools/persistent_dict.py           |  3 +++
 pytools/test/test_persistent_dict.py | 27 +++++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 13d8b10..60131d8 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -483,6 +483,9 @@ class KeyBuilder:
     def update_for_function(self, key_hash: Hash, key: Any) -> None:
         self.rec(key_hash, key.__module__ + key.__qualname__)
 
+        if key.__closure__:
+            self.rec(key_hash, tuple(c.cell_contents for c in key.__closure__))
+
     # }}}
 
 # }}}
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index 578b6ad..884d926 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -730,10 +730,31 @@ def global_fun2():
 def test_hash_function() -> None:
     keyb = KeyBuilder()
 
+    # {{{ global functions
+
     assert keyb(global_fun) == keyb(global_fun) == \
         "51b5980dd3a8aa13f6e83869e4a04c22973d7aaf96cb22899abdfdc55e15c9b2"
     assert keyb(global_fun) != keyb(global_fun2)
 
+    # }}}
+
+    # {{{ closures
+
+    def get_fun(x):
+        def add_x(y):
+            return x + y
+        return add_x
+
+    f1 = get_fun(1)
+    f2 = get_fun(2)
+
+    assert f1 != f2
+    assert keyb(f1) != keyb(f2)
+
+    # }}}
+
+    # {{{ local functions
+
     def local_fun():
         pass
 
@@ -744,6 +765,10 @@ def test_hash_function() -> None:
         "fc58f5b0130df821913c848749eb03f5dcd4da7a568c6130f1c0cfb96ed0d12d"
     assert keyb(local_fun) != keyb(local_fun2)
 
+    # }}}
+
+    # {{{ methods
+
     class C1:
         def method(self):
             pass
@@ -756,6 +781,8 @@ def test_hash_function() -> None:
         "3013eb424dac133a57bd70cb6084d2a2f349a247714efc508fe3b10b99b6f717"
     assert keyb(C1.method) != keyb(C2.method)
 
+    # }}}
+
 
 if __name__ == "__main__":
     if len(sys.argv) > 1:
-- 
GitLab