From 696e65310164d01f4b62cab98bc38c0c24e5249b Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 29 Jun 2021 06:45:58 -0500 Subject: [PATCH] Avoids closure in LazilyCompilingFunctionCaller.__call__ --- arraycontext/impl/pytato/compile.py | 43 ++++++++++++++++------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 10d133f..2098208 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -128,6 +128,26 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...] return pmap(arg_id_to_arg), pmap(arg_id_to_descr) +def _get_f_placeholder_args(arg, iarg, arg_id_to_name): + """ + Helper for :class:`LazilyCompilingFunctionCaller.__call__`. Returns the + placeholder version of an argument to + :attr:`LazilyCompilingFunctionCaller.f`. + """ + if np.isscalar(arg): + name = arg_id_to_name[(iarg,)] + return pt.make_placeholder((), np.dtype(arg), name) + elif is_array_container(arg): + def _rec_to_placeholder(keys, ary): + name = arg_id_to_name[(iarg,) + keys] + return pt.make_placeholder(ary.shape, ary.dtype, + name) + return rec_keyed_map_array_container(_rec_to_placeholder, + arg) + else: + raise NotImplementedError(type(arg)) + + @dataclass class LazilyCompilingFunctionCaller: """ @@ -173,26 +193,11 @@ class LazilyCompilingFunctionCaller: output_naming_map = {} # input_naming_map: argument id to placeholder name in the generated # pytato DAG. - input_naming_map = {} - - def to_placeholder(arg, pos): - if np.isscalar(arg): - name = f"_actx_in_{pos}" - input_naming_map[(pos, )] = name - return pt.make_placeholder((), np.dtype(arg), name) - elif is_array_container(arg): - def _rec_to_placeholder(keys, ary): - name = (f"_actx_in_{pos}_" - + _ary_container_key_stringifier(keys)) - input_naming_map[(pos,) + keys] = name - return pt.make_placeholder(ary.shape, ary.dtype, - name) - return rec_keyed_map_array_container(_rec_to_placeholder, - arg) - else: - raise NotImplementedError(type(arg)) + input_naming_map = { + arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + for arg_id in arg_id_to_arg} - outputs = self.f(*[to_placeholder(arg, iarg) + outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map) for iarg, arg in enumerate(args)]) if not is_array_container(outputs): -- GitLab