From 9b7559032d8071ccdd482edbf69b1f5e9da9b69b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 4 Sep 2024 14:57:54 -0500
Subject: [PATCH] Numpy actx: cache execuctor

---
 arraycontext/context.py             |  2 +-
 arraycontext/impl/numpy/__init__.py | 27 +++++++++++++++++----------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/arraycontext/context.py b/arraycontext/context.py
index d296f8f..30f58cb 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -339,7 +339,7 @@ class ArrayContext(ABC):
 
     @abstractmethod
     def call_loopy(self,
-                   program: "loopy.TranslationUnit",
+                   t_unit: "loopy.TranslationUnit",
                    **kwargs: Any) -> Dict[str, Array]:
         """Execute the :mod:`loopy` program *program* on the arguments
         *kwargs*.
diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py
index 77b7b49..f8ba95e 100644
--- a/arraycontext/impl/numpy/__init__.py
+++ b/arraycontext/impl/numpy/__init__.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+
 """
 .. currentmodule:: arraycontext
 
@@ -30,7 +33,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any, Dict
+from typing import Any
 
 import numpy as np
 
@@ -39,6 +42,7 @@ from pytools.tag import ToTagSetConvertible
 
 from arraycontext.container.traversal import rec_map_array_container, with_array_context
 from arraycontext.context import (
+    Array,
     ArrayContext,
     ArrayOrContainerOrScalar,
     ArrayOrContainerOrScalarT,
@@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext):
 
     .. automethod:: __init__
     """
+
+    _loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
+
     def __init__(self) -> None:
         super().__init__()
-        self._loopy_transform_cache: \
-                Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
+        self._loopy_transform_cache = {}
 
     array_types = (NumpyNonObjectArray,)
 
@@ -88,17 +94,18 @@ class NumpyArrayContext(ArrayContext):
                  ) -> NumpyOrContainerOrScalar:
         return array
 
-    def call_loopy(self, t_unit, **kwargs):
+    def call_loopy(
+                self,
+                t_unit: lp.TranslationUnit, **kwargs: Any
+            ) -> dict[str, Array]:
         t_unit = t_unit.copy(target=lp.ExecutableCTarget())
         try:
-            t_unit = self._loopy_transform_cache[t_unit]
+            executor = self._loopy_transform_cache[t_unit]
         except KeyError:
-            orig_t_unit = t_unit
-            t_unit = self.transform_loopy_program(t_unit)
-            self._loopy_transform_cache[orig_t_unit] = t_unit
-            del orig_t_unit
+            executor = self.transform_loopy_program(t_unit).executor()
+            self._loopy_transform_cache[t_unit] = executor
 
-        _, result = t_unit(**kwargs)
+        _, result = executor(**kwargs)
 
         return result
 
-- 
GitLab