Skip to content
Snippets Groups Projects
Commit 9b755903 authored by Andreas Klöckner's avatar Andreas Klöckner Committed by Andreas Klöckner
Browse files

Numpy actx: cache execuctor

parent 15216d45
No related branches found
No related tags found
No related merge requests found
Pipeline #585656 failed
...@@ -339,7 +339,7 @@ class ArrayContext(ABC): ...@@ -339,7 +339,7 @@ class ArrayContext(ABC):
@abstractmethod @abstractmethod
def call_loopy(self, def call_loopy(self,
program: "loopy.TranslationUnit", t_unit: "loopy.TranslationUnit",
**kwargs: Any) -> Dict[str, Array]: **kwargs: Any) -> Dict[str, Array]:
"""Execute the :mod:`loopy` program *program* on the arguments """Execute the :mod:`loopy` program *program* on the arguments
*kwargs*. *kwargs*.
......
from __future__ import annotations
""" """
.. currentmodule:: arraycontext .. currentmodule:: arraycontext
...@@ -30,7 +33,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN ...@@ -30,7 +33,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
""" """
from typing import Any, Dict from typing import Any
import numpy as np import numpy as np
...@@ -39,6 +42,7 @@ from pytools.tag import ToTagSetConvertible ...@@ -39,6 +42,7 @@ from pytools.tag import ToTagSetConvertible
from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import ( from arraycontext.context import (
Array,
ArrayContext, ArrayContext,
ArrayOrContainerOrScalar, ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT, ArrayOrContainerOrScalarT,
...@@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext): ...@@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext):
.. automethod:: __init__ .. automethod:: __init__
""" """
_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._loopy_transform_cache: \ self._loopy_transform_cache = {}
Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
array_types = (NumpyNonObjectArray,) array_types = (NumpyNonObjectArray,)
...@@ -88,17 +94,18 @@ class NumpyArrayContext(ArrayContext): ...@@ -88,17 +94,18 @@ class NumpyArrayContext(ArrayContext):
) -> NumpyOrContainerOrScalar: ) -> NumpyOrContainerOrScalar:
return array return array
def call_loopy(self, t_unit, **kwargs): def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget()) t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try: try:
t_unit = self._loopy_transform_cache[t_unit] executor = self._loopy_transform_cache[t_unit]
except KeyError: except KeyError:
orig_t_unit = t_unit executor = self.transform_loopy_program(t_unit).executor()
t_unit = self.transform_loopy_program(t_unit) self._loopy_transform_cache[t_unit] = executor
self._loopy_transform_cache[orig_t_unit] = t_unit
del orig_t_unit
_, result = t_unit(**kwargs) _, result = executor(**kwargs)
return result return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment