Skip to content
Snippets Groups Projects
Commit 1860d17c authored and committed by Andreas Klöckner
Browse files

Inline the transform cache handling in PyOpenCLArrayContext.call_loopy

parent 82bb9386
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,6 @@ import operator
import numpy as np
from pytools import memoize_method
from pytools.tag import Tag
from arraycontext.metadata import FirstAxisIsElementsTag
......@@ -288,6 +287,8 @@ class PyOpenCLArrayContext(ArrayContext):
"are running Python in debug mode. Use 'python -O' for "
"a noticeable speed improvement.")
self._loopy_transform_cache = {}
def _get_fake_numpy_namespace(self):
    """Build the fake numpy namespace object bound to this array context."""
    # The namespace is handed this context instance on construction.
    fake_numpy_ns = PyOpenCLFakeNumpyNamespace(self)
    return fake_numpy_ns
......@@ -311,14 +312,15 @@ class PyOpenCLArrayContext(ArrayContext):
return array.get(queue=self.queue)
def call_loopy(self, t_unit, **kwargs):
t_unit = self.transform_loopy_program(t_unit)
from arraycontext.loopy import get_default_entrypoint
default_entrypoint = get_default_entrypoint(t_unit)
prg_name = default_entrypoint.name
try:
t_unit = self._loopy_transform_cache[t_unit]
except KeyError:
t_unit = self.transform_loopy_program(t_unit)
evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator)
if self._wait_event_queue_length is not False:
prg_name = t_unit.default_entrypoint.name
wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault(
prg_name, [])
......@@ -337,13 +339,17 @@ class PyOpenCLArrayContext(ArrayContext):
# }}}
@memoize_method
def transform_loopy_program(self, t_unit):
try:
return self._loopy_transform_cache[t_unit]
except KeyError:
pass
orig_t_unit = t_unit
# accommodate loopy with and without kernel callables
import loopy as lp
from arraycontext.loopy import get_default_entrypoint
default_entrypoint = get_default_entrypoint(t_unit)
default_entrypoint = t_unit.default_entrypoint
options = default_entrypoint.options
if not (options.return_dict and options.no_numpy):
raise ValueError("Loopy kernel passed to call_loopy must "
......@@ -385,7 +391,10 @@ class PyOpenCLArrayContext(ArrayContext):
if inner_iname is not None:
t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0")
return lp.tag_inames(t_unit, {outer_iname: "g.0"})
t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"})
self._loopy_transform_cache[orig_t_unit] = t_unit
return t_unit
def tag(self, tags: Union[Sequence[Tag], Tag], array):
# Sorry, not capable.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment