diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 74eadb39b8a7f14b7b92775954e66147b87413ca..068d86b73873ff052e7620c318dad3331aead808 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -33,7 +33,6 @@ import operator import numpy as np -from pytools import memoize_method from pytools.tag import Tag from arraycontext.metadata import FirstAxisIsElementsTag @@ -241,6 +240,8 @@ class PyOpenCLArrayContext(ArrayContext): "are running Python in debug mode. Use 'python -O' for " "a noticeable speed improvement.") + self._loopy_transform_cache = {} + def _get_fake_numpy_namespace(self): return PyOpenCLFakeNumpyNamespace(self) @@ -264,14 +265,15 @@ class PyOpenCLArrayContext(ArrayContext): return array.get(queue=self.queue) def call_loopy(self, t_unit, **kwargs): - t_unit = self.transform_loopy_program(t_unit) - from arraycontext.loopy import get_default_entrypoint - default_entrypoint = get_default_entrypoint(t_unit) - prg_name = default_entrypoint.name + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + t_unit = self.transform_loopy_program(t_unit) evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator) if self._wait_event_queue_length is not False: + prg_name = t_unit.default_entrypoint.name wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault( prg_name, []) @@ -290,13 +292,17 @@ class PyOpenCLArrayContext(ArrayContext): # }}} - @memoize_method def transform_loopy_program(self, t_unit): + try: + return self._loopy_transform_cache[t_unit] + except KeyError: + pass + orig_t_unit = t_unit + # accommodate loopy with and without kernel callables import loopy as lp - from arraycontext.loopy import get_default_entrypoint - default_entrypoint = get_default_entrypoint(t_unit) + default_entrypoint = t_unit.default_entrypoint options = default_entrypoint.options if not (options.return_dict and options.no_numpy): raise ValueError("Loopy kernel passed to call_loopy must " @@ -337,7 +343,10 @@ class PyOpenCLArrayContext(ArrayContext): if inner_iname is not None: t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0") - return lp.tag_inames(t_unit, {outer_iname: "g.0"}) + t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"}) + + self._loopy_transform_cache[orig_t_unit] = t_unit + return t_unit def tag(self, tags: Union[Sequence[Tag], Tag], array): # Sorry, not capable.