From 696e65310164d01f4b62cab98bc38c0c24e5249b Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Tue, 29 Jun 2021 06:45:58 -0500
Subject: [PATCH] Avoids closure in LazilyCompilingFunctionCaller.__call__

---
 arraycontext/impl/pytato/compile.py | 43 ++++++++++++++++-------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 10d133f..2098208 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -128,6 +128,26 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...]
     return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
 
 
+def _get_f_placeholder_args(arg, iarg, arg_id_to_name):
+    """
+    Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the
+    placeholder version of an argument to
+    :attr:`LazilyCompilingFunctionCaller.f`.
+    """
+    if np.isscalar(arg):
+        name = arg_id_to_name[(iarg,)]
+        return pt.make_placeholder((), np.dtype(arg), name)
+    elif is_array_container(arg):
+        def _rec_to_placeholder(keys, ary):
+            name = arg_id_to_name[(iarg,) + keys]
+            return pt.make_placeholder(ary.shape, ary.dtype,
+                                       name)
+        return rec_keyed_map_array_container(_rec_to_placeholder,
+                                                arg)
+    else:
+        raise NotImplementedError(type(arg))
+
+
 @dataclass
 class LazilyCompilingFunctionCaller:
     """
@@ -173,26 +193,11 @@ class LazilyCompilingFunctionCaller:
         output_naming_map = {}
         # input_naming_map: argument id to placeholder name in the generated
         # pytato DAG.
-        input_naming_map = {}
-
-        def to_placeholder(arg, pos):
-            if np.isscalar(arg):
-                name = f"_actx_in_{pos}"
-                input_naming_map[(pos, )] = name
-                return pt.make_placeholder((), np.dtype(arg), name)
-            elif is_array_container(arg):
-                def _rec_to_placeholder(keys, ary):
-                    name = (f"_actx_in_{pos}_"
-                            + _ary_container_key_stringifier(keys))
-                    input_naming_map[(pos,) + keys] = name
-                    return pt.make_placeholder(ary.shape, ary.dtype,
-                                               name)
-                return rec_keyed_map_array_container(_rec_to_placeholder,
-                                                     arg)
-            else:
-                raise NotImplementedError(type(arg))
+        input_naming_map = {
+            arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}"
+            for arg_id in arg_id_to_arg}
 
-        outputs = self.f(*[to_placeholder(arg, iarg)
+        outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map)
                            for iarg, arg in enumerate(args)])
 
         if not is_array_container(outputs):
-- 
GitLab