Skip to content
Snippets Groups Projects
Commit 3a377be0 authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

filter out arguments that aren't needed to the entrypoint

parent fbb75928
No related branches found
No related tags found
No related merge requests found
......@@ -172,7 +172,7 @@ class PytatoCompiledOperator:
for pos, arg in enumerate(args):
if isinstance(arg, np.number):
input_kwargs_to_loopy[self.input_id_to_name_in_program[pos]] = (
input_kwargs_to_loopy[self.input_id_to_name_in_program[(pos,)]] = (
arg)
elif is_array_container(arg):
def _extract_lpy_kwargs(keys, ary):
......@@ -196,6 +196,17 @@ class PytatoCompiledOperator:
else:
raise NotImplementedError(type(arg))
# {{{ the generated program might not have depended on some of the
# inputs => do not pass those to the loopy kernel
input_kwargs_to_loopy = {arg_name: arg
for arg_name, arg in input_kwargs_to_loopy.items()
if arg_name in (self.pytato_program
.program.default_entrypoint
.arg_dict)}
# }}}
evt, out_dict = self.pytato_program(queue=self.actx.queue,
allocator=self.actx.allocator,
**input_kwargs_to_loopy)
......@@ -312,12 +323,12 @@ class PytatoArrayContext(ArrayContext):
def to_placeholder(input_like, pos):
if isinstance(input_like, np.number):
name = f"_pt_in_{pos}"
name = f"_actx_in_{pos}"
input_naming_map[(pos, )] = name
return pt.make_placeholder((), input_like.dtype, name)
elif is_array_container(input_like):
def _rec_to_placeholder(keys, ary):
name = f"_pt_in_{pos}_" + "_".join(str(key)
name = f"_actx_in_{pos}_" + "_".join(str(key)
for key in keys)
input_naming_map[(pos,) + keys] = name
return pt.make_placeholder(ary.shape, ary.dtype,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment