From 6e768a987c73188dccb04cfc29c450c161f6b879 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 30 Jun 2021 12:28:12 -0500
Subject: [PATCH] Make _get_scalar_func_loopy_program a function (not an actx
 method)

---
 arraycontext/context.py    | 29 ---------------------------
 arraycontext/fake_numpy.py | 40 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 30 deletions(-)

diff --git a/arraycontext/context.py b/arraycontext/context.py
index ecfde93..58c4299 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -201,35 +201,6 @@ class ArrayContext(ABC):
             array understood by the context.
         """
 
-    @memoize_method
-    def _get_scalar_func_loopy_program(self, c_name, nargs, naxes):
-        from pymbolic import var
-
-        var_names = ["i%d" % i for i in range(naxes)]
-        size_names = ["n%d" % i for i in range(naxes)]
-        subscript = tuple(var(vname) for vname in var_names)
-        from islpy import make_zero_and_vars
-        v = make_zero_and_vars(var_names, params=size_names)
-        domain = v[0].domain()
-        for vname, sname in zip(var_names, size_names):
-            domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])
-
-        domain_bset, = domain.get_basic_sets()
-
-        import loopy as lp
-        from .loopy import make_loopy_program
-        from arraycontext.transform_metadata import ElementwiseMapKernelTag
-        return make_loopy_program(
-                [domain_bset],
-                [
-                    lp.Assignment(
-                        var("out")[subscript],
-                        var(c_name)(*[
-                            var("inp%d" % i)[subscript] for i in range(nargs)]))
-                    ],
-                name="actx_special_%s" % c_name,
-                tags=(ElementwiseMapKernelTag(),))
-
     @abstractmethod
     def freeze(self, array):
         """Return a version of the context-defined array *array* that is
diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index 1f208f3..0c1309c 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -27,6 +27,44 @@ import numpy as np
 from arraycontext.container import is_array_container, serialize_container
 from arraycontext.container.traversal import (
         rec_map_array_container, multimapped_over_array_containers)
+from pytools import memoize_in
+
+
+# {{{ _get_scalar_func_loopy_program
+
+def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
+    @memoize_in(actx, _get_scalar_func_loopy_program)
+    def get(c_name, nargs, naxes):
+        from pymbolic import var
+
+        var_names = ["i%d" % i for i in range(naxes)]
+        size_names = ["n%d" % i for i in range(naxes)]
+        subscript = tuple(var(vname) for vname in var_names)
+        from islpy import make_zero_and_vars
+        v = make_zero_and_vars(var_names, params=size_names)
+        domain = v[0].domain()
+        for vname, sname in zip(var_names, size_names):
+            domain = domain & v[0].le_set(v[vname]) & v[vname].lt_set(v[sname])
+
+        domain_bset, = domain.get_basic_sets()
+
+        import loopy as lp
+        from .loopy import make_loopy_program
+        from arraycontext.transform_metadata import ElementwiseMapKernelTag
+        return make_loopy_program(
+                [domain_bset],
+                [
+                    lp.Assignment(
+                        var("out")[subscript],
+                        var(c_name)(*[
+                            var("inp%d" % i)[subscript] for i in range(nargs)]))
+                    ],
+                name="actx_special_%s" % c_name,
+                tags=(ElementwiseMapKernelTag(),))
+
+    return get(c_name, nargs, naxes)
+
+# }}}
 
 
 # {{{ BaseFakeNumpyNamespace
@@ -112,7 +150,7 @@ class BaseFakeNumpyNamespace:
             actx = self._array_context
             # FIXME: Maybe involve loopy type inference?
             result = actx.empty(args[0].shape, args[0].dtype)
-            prg = actx._get_scalar_func_loopy_program(
+            prg = _get_scalar_func_loopy_program(actx,
                     c_name, nargs=len(args), naxes=len(args[0].shape))
             actx.call_loopy(prg, out=result,
                     **{"inp%d" % i: arg for i, arg in enumerate(args)})
-- 
GitLab