From c0062855e1cdf07ebe1ed37d5bb14fb48af4d0f7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 27 Jun 2021 00:50:19 -0500
Subject: [PATCH] Put call_loopy in charge of the transform cache

---
 arraycontext/impl/pyopencl/__init__.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index b29fd33..0431902 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -167,7 +167,10 @@ class PyOpenCLArrayContext(ArrayContext):
         try:
             t_unit = self._loopy_transform_cache[t_unit]
         except KeyError:
+            orig_t_unit = t_unit
             t_unit = self.transform_loopy_program(t_unit)
+            self._loopy_transform_cache[orig_t_unit] = t_unit
+            del orig_t_unit
 
         evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator)
 
@@ -192,12 +195,6 @@ class PyOpenCLArrayContext(ArrayContext):
     # }}}
 
     def transform_loopy_program(self, t_unit):
-        try:
-            return self._loopy_transform_cache[t_unit]
-        except KeyError:
-            pass
-        orig_t_unit = t_unit
-
         from warnings import warn
         warn("Using arraycontext.PyOpenCLArrayContext.transform_loopy_program "
                 "to transform a program. This is deprecated and will stop working "
@@ -259,7 +256,6 @@ class PyOpenCLArrayContext(ArrayContext):
             t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0")
         t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"})
 
-        self._loopy_transform_cache[orig_t_unit] = t_unit
         return t_unit
 
     def tag(self, tags: Union[Sequence[Tag], Tag], array):
-- 
GitLab