diff --git a/arraycontext/context.py b/arraycontext/context.py index d296f8f7dd276d80e678fea99cfbf27b71f755af..30f58cb1e2494447f1496c47e7093b8a37955771 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -339,7 +339,7 @@ class ArrayContext(ABC): @abstractmethod def call_loopy(self, - program: "loopy.TranslationUnit", + t_unit: "loopy.TranslationUnit", **kwargs: Any) -> Dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 77b7b49f3983bf40a3f3e59684c7eed5f4638523..f8ba95e3996aebedd57dd7c316ce33f721fbb65a 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + """ .. currentmodule:: arraycontext @@ -30,7 +33,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Dict +from typing import Any import numpy as np @@ -39,6 +42,7 @@ from pytools.tag import ToTagSetConvertible from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ( + Array, ArrayContext, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, @@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext): .. automethod:: __init__ """ + + _loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase] + def __init__(self) -> None: super().__init__() - self._loopy_transform_cache: \ - Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + self._loopy_transform_cache = {} array_types = (NumpyNonObjectArray,) @@ -88,17 +94,18 @@ class NumpyArrayContext(ArrayContext): ) -> NumpyOrContainerOrScalar: return array - def call_loopy(self, t_unit, **kwargs): + def call_loopy( + self, + t_unit: lp.TranslationUnit, **kwargs: Any + ) -> dict[str, Array]: t_unit = t_unit.copy(target=lp.ExecutableCTarget()) try: - t_unit = self._loopy_transform_cache[t_unit] + executor = self._loopy_transform_cache[t_unit] except KeyError: - orig_t_unit = t_unit - t_unit = self.transform_loopy_program(t_unit) - self._loopy_transform_cache[orig_t_unit] = t_unit - del orig_t_unit + executor = self.transform_loopy_program(t_unit).executor() + self._loopy_transform_cache[t_unit] = executor - _, result = t_unit(**kwargs) + _, result = executor(**kwargs) return result