# {{{ _get_scalar_func_loopy_program

def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
    """Return a loopy program that applies the function *c_name* elementwise.

    The generated kernel has a single assignment of the form
    ``out[i0, ..., i<naxes-1>] = c_name(inp0[...], ..., inp<nargs-1>[...])``,
    where every index ``i<k>`` ranges over ``0 <= i<k> < n<k>`` and the
    sizes ``n<k>`` are parameters of the loop domain.

    :arg actx: the object handed to :func:`pytools.memoize_in` as the
        container for the cache, so that built programs are looked up by
        ``(c_name, nargs, naxes)`` instead of being rebuilt on each call.
    :arg c_name: name of the function called in the kernel; also used to
        form the kernel name ``actx_special_<c_name>``.
    :arg nargs: number of input arrays, named ``inp0``, ``inp1``, ...
    :arg naxes: number of array axes, i.e. the number of loop indices.
    :returns: whatever ``make_loopy_program`` returns for this kernel,
        tagged with ``ElementwiseMapKernelTag``.
    """
    from pytools import memoize_in

    # The outer function object itself serves as the cache identifier.
    @memoize_in(actx, _get_scalar_func_loopy_program)
    def build_program(c_name, nargs, naxes):
        import loopy as lp
        from islpy import make_zero_and_vars
        from pymbolic import var

        from .loopy import make_loopy_program
        from arraycontext.transform_metadata import ElementwiseMapKernelTag

        axes = range(naxes)
        inames = ["i%d" % i for i in axes]
        sizes = ["n%d" % i for i in axes]

        # Assemble the index set {0 <= i_k < n_k for every axis k}.
        v = make_zero_and_vars(inames, params=sizes)
        zero = v[0]
        domain = zero.domain()
        for iname, size in zip(inames, sizes):
            lower_bound = zero.le_set(v[iname])
            upper_bound = v[iname].lt_set(v[size])
            domain = domain & lower_bound & upper_bound

        # Unpacking into a one-element target fails unless the domain
        # consists of exactly one basic set.
        domain_bset, = domain.get_basic_sets()

        subscript = tuple(map(var, inames))
        operands = [var("inp%d" % i)[subscript] for i in range(nargs)]
        assignment = lp.Assignment(
                var("out")[subscript],
                var(c_name)(*operands))

        return make_loopy_program(
                [domain_bset],
                [assignment],
                name="actx_special_%s" % c_name,
                tags=(ElementwiseMapKernelTag(),))

    return build_program(c_name, nargs, naxes)

# }}}