diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index a8748cab47a0b2dbb1bd3751812678c472bb5a9d..76a1143680146ae89b2529545b42e19d472b1b8d 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -172,7 +172,7 @@ class PytatoCompiledOperator: for pos, arg in enumerate(args): if isinstance(arg, np.number): - input_kwargs_to_loopy[self.input_id_to_name_in_program[pos]] = ( + input_kwargs_to_loopy[self.input_id_to_name_in_program[(pos,)]] = ( arg) elif is_array_container(arg): def _extract_lpy_kwargs(keys, ary): @@ -196,6 +196,17 @@ class PytatoCompiledOperator: else: raise NotImplementedError(type(arg)) + # {{{ the generated program might not have depended on some of the + # inputs => do not pass those to the loopy kernel + + input_kwargs_to_loopy = {arg_name: arg + for arg_name, arg in input_kwargs_to_loopy.items() + if arg_name in (self.pytato_program + .program.default_entrypoint + .arg_dict)} + + # }}} + evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, **input_kwargs_to_loopy) @@ -312,12 +323,12 @@ class PytatoArrayContext(ArrayContext): def to_placeholder(input_like, pos): if isinstance(input_like, np.number): - name = f"_pt_in_{pos}" + name = f"_actx_in_{pos}" input_naming_map[(pos, )] = name return pt.make_placeholder((), input_like.dtype, name) elif is_array_container(input_like): def _rec_to_placeholder(keys, ary): - name = f"_pt_in_{pos}_" + "_".join(str(key) + name = f"_actx_in_{pos}_" + "_".join(str(key) for key in keys) input_naming_map[(pos,) + keys] = name return pt.make_placeholder(ary.shape, ary.dtype,