From 3a377be0e28013f7d9fbfe259b68af78caed0a52 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sat, 12 Jun 2021 17:35:40 -0500 Subject: [PATCH] filter out arguments that aren't needed to the entrypoint --- arraycontext/impl/pytato.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index a8748ca..76a1143 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -172,7 +172,7 @@ class PytatoCompiledOperator: for pos, arg in enumerate(args): if isinstance(arg, np.number): - input_kwargs_to_loopy[self.input_id_to_name_in_program[pos]] = ( + input_kwargs_to_loopy[self.input_id_to_name_in_program[(pos,)]] = ( arg) elif is_array_container(arg): def _extract_lpy_kwargs(keys, ary): @@ -196,6 +196,17 @@ class PytatoCompiledOperator: else: raise NotImplementedError(type(arg)) + # {{{ the generated program might not have depended on some of the + # inputs => do not pass those to the loopy kernel + + input_kwargs_to_loopy = {arg_name: arg + for arg_name, arg in input_kwargs_to_loopy.items() + if arg_name in (self.pytato_program + .program.default_entrypoint + .arg_dict)} + + # }}} + evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, **input_kwargs_to_loopy) @@ -312,12 +323,12 @@ class PytatoArrayContext(ArrayContext): def to_placeholder(input_like, pos): if isinstance(input_like, np.number): - name = f"_pt_in_{pos}" + name = f"_actx_in_{pos}" input_naming_map[(pos, )] = name return pt.make_placeholder((), input_like.dtype, name) elif is_array_container(input_like): def _rec_to_placeholder(keys, ary): - name = f"_pt_in_{pos}_" + "_".join(str(key) + name = f"_actx_in_{pos}_" + "_".join(str(key) for key in keys) input_naming_map[(pos,) + keys] = name return pt.make_placeholder(ary.shape, ary.dtype, -- GitLab