diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index b29fd3380332d5403564f11a74ce8e630ebc92f0..04319027d8a8d2415fdd0edb41063415356bca3c 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -167,7 +167,10 @@ class PyOpenCLArrayContext(ArrayContext): try: t_unit = self._loopy_transform_cache[t_unit] except KeyError: + orig_t_unit = t_unit t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator) @@ -192,12 +195,6 @@ class PyOpenCLArrayContext(ArrayContext): # }}} def transform_loopy_program(self, t_unit): - try: - return self._loopy_transform_cache[t_unit] - except KeyError: - pass - orig_t_unit = t_unit - from warnings import warn warn("Using arraycontext.PyOpenCLArrayContext.transform_loopy_program " "to transform a program. This is deprecated and will stop working " @@ -259,7 +256,6 @@ class PyOpenCLArrayContext(ArrayContext): t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0") t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"}) - self._loopy_transform_cache[orig_t_unit] = t_unit return t_unit def tag(self, tags: Union[Sequence[Tag], Tag], array):