Skip to content
Snippets Groups Projects
Commit 1b138f29 authored by Matthias Diener
Browse files

Merge remote-tracking branch 'origin/main' into pytato

parents 4a08cada 1860d17c
No related branches found
No related tags found
No related merge requests found
...@@ -33,7 +33,6 @@ import operator ...@@ -33,7 +33,6 @@ import operator
import numpy as np import numpy as np
from pytools import memoize_method
from pytools.tag import Tag from pytools.tag import Tag
from arraycontext.metadata import FirstAxisIsElementsTag from arraycontext.metadata import FirstAxisIsElementsTag
...@@ -241,6 +240,8 @@ class PyOpenCLArrayContext(ArrayContext): ...@@ -241,6 +240,8 @@ class PyOpenCLArrayContext(ArrayContext):
"are running Python in debug mode. Use 'python -O' for " "are running Python in debug mode. Use 'python -O' for "
"a noticeable speed improvement.") "a noticeable speed improvement.")
self._loopy_transform_cache = {}
def _get_fake_numpy_namespace(self): def _get_fake_numpy_namespace(self):
return PyOpenCLFakeNumpyNamespace(self) return PyOpenCLFakeNumpyNamespace(self)
...@@ -264,14 +265,15 @@ class PyOpenCLArrayContext(ArrayContext): ...@@ -264,14 +265,15 @@ class PyOpenCLArrayContext(ArrayContext):
return array.get(queue=self.queue) return array.get(queue=self.queue)
def call_loopy(self, t_unit, **kwargs): def call_loopy(self, t_unit, **kwargs):
t_unit = self.transform_loopy_program(t_unit) try:
from arraycontext.loopy import get_default_entrypoint t_unit = self._loopy_transform_cache[t_unit]
default_entrypoint = get_default_entrypoint(t_unit) except KeyError:
prg_name = default_entrypoint.name t_unit = self.transform_loopy_program(t_unit)
evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator) evt, result = t_unit(self.queue, **kwargs, allocator=self.allocator)
if self._wait_event_queue_length is not False: if self._wait_event_queue_length is not False:
prg_name = t_unit.default_entrypoint.name
wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault( wait_event_queue = self._kernel_name_to_wait_event_queue.setdefault(
prg_name, []) prg_name, [])
...@@ -290,13 +292,17 @@ class PyOpenCLArrayContext(ArrayContext): ...@@ -290,13 +292,17 @@ class PyOpenCLArrayContext(ArrayContext):
# }}} # }}}
@memoize_method
def transform_loopy_program(self, t_unit): def transform_loopy_program(self, t_unit):
try:
return self._loopy_transform_cache[t_unit]
except KeyError:
pass
orig_t_unit = t_unit
# accommodate loopy with and without kernel callables # accommodate loopy with and without kernel callables
import loopy as lp import loopy as lp
from arraycontext.loopy import get_default_entrypoint default_entrypoint = t_unit.default_entrypoint
default_entrypoint = get_default_entrypoint(t_unit)
options = default_entrypoint.options options = default_entrypoint.options
if not (options.return_dict and options.no_numpy): if not (options.return_dict and options.no_numpy):
raise ValueError("Loopy kernel passed to call_loopy must " raise ValueError("Loopy kernel passed to call_loopy must "
...@@ -337,7 +343,10 @@ class PyOpenCLArrayContext(ArrayContext): ...@@ -337,7 +343,10 @@ class PyOpenCLArrayContext(ArrayContext):
if inner_iname is not None: if inner_iname is not None:
t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0") t_unit = lp.split_iname(t_unit, inner_iname, 16, inner_tag="l.0")
return lp.tag_inames(t_unit, {outer_iname: "g.0"}) t_unit = lp.tag_inames(t_unit, {outer_iname: "g.0"})
self._loopy_transform_cache[orig_t_unit] = t_unit
return t_unit
def tag(self, tags: Union[Sequence[Tag], Tag], array): def tag(self, tags: Union[Sequence[Tag], Tag], array):
# Sorry, not capable. # Sorry, not capable.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment