From 1860d17c5fcdb7f69a7cedc89765a559b22f45b4 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 14 Jun 2021 01:20:15 -0500 Subject: [PATCH] Inline the transform cache handling in PyOpenCLArrayContext.call_loopy --- arraycontext/impl/pyopencl.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/arraycontext/impl/pyopencl.py b/arraycontext/impl/pyopencl.py index 1f9c467..0cd3f64 100644 --- a/arraycontext/impl/pyopencl.py +++ b/arraycontext/impl/pyopencl.py @@ -33,7 +33,6 @@ import operator import numpy as np -from pytools import memoize_method from pytools.tag import Tag from arraycontext.metadata import FirstAxisIsElementsTag @@ -288,6 +287,8 @@ class PyOpenCLArrayContext(ArrayContext): "are running Python in debug mode. Use 'python -O' for " "a noticeable speed improvement.") + self._loopy_transform_cache = {} + def _get_fake_numpy_namespace(self): return PyOpenCLFakeNumpyNamespace(self) @@ -311,14 +312,15 @@ class PyOpenCLArrayContext(ArrayContext): return array.get(queue=self.queue) def call_loopy(self, t_unit, **kwargs): - t_unit = self.transform_loopy_program(t_unit) - from arraycontext.loopy import get_default_entrypoint - default_entrypoint = get_default_entrypoint(t_unit) - prg_name = default_entrypoint.name + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + t_unit = self.transform_loopy_program(t_unit) evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator) if self._wait_event_queue_length is not False: + prg_name = t_unit.default_entrypoint.name wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault( prg_name, []) @@ -337,13 +339,17 @@ class PyOpenCLArrayContext(ArrayContext): # }}} - @memoize_method def transform_loopy_program(self, t_unit): + try: + return self._loopy_transform_cache[t_unit] + except KeyError: + pass + orig_t_unit = t_unit + # accommodate loopy with and without kernel callables import loopy as lp - from arraycontext.loopy import get_default_entrypoint - default_entrypoint = get_default_entrypoint(t_unit) + default_entrypoint = t_unit.default_entrypoint options = default_entrypoint.options if not (options.return_dict and options.no_numpy): raise ValueError("Loopy kernel passed to call_loopy must " @@ -385,7 +391,10 @@ class PyOpenCLArrayContext(ArrayContext): if inner_iname is not None: t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0") - return lp.tag_inames(t_unit, {outer_iname: "g.0"}) + t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"}) + + self._loopy_transform_cache[orig_t_unit] = t_unit + return t_unit def tag(self, tags: Union[Sequence[Tag], Tag], array): # Sorry, not capable. -- GitLab