diff --git a/loopy/library/function.py b/loopy/library/function.py index 8fcdcd6da7aaf66248905bb60f3d8f096b327173..8338875d0ec9f57dcce702a603293d038a9fbd02 100644 --- a/loopy/library/function.py +++ b/loopy/library/function.py @@ -47,7 +47,8 @@ class MakeTupleCallable(ScalarCallable): class IndexOfCallable(ScalarCallable): def with_types(self, arg_id_to_dtype, kernel, program_callables_info): - new_arg_id_to_dtype = arg_id_to_dtype.copy() + new_arg_id_to_dtype = dict((i, dtype) for i, dtype in + arg_id_to_dtype.items() if dtype is not None) new_arg_id_to_dtype[-1] = kernel.index_dtype return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype), diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 65c91871ad276d5e99c295971ca4ab2522176742..cf956f68f821bb1ac4fe6c2df2cb0defe7d426ac 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -37,7 +37,7 @@ from loopy.kernel.instruction import _DataObliviousInstruction from loopy.program import ProgramCallablesInfo from loopy.symbolic import SubArrayRef, LinearSubscript -from pymbolic.primitives import Variable, Subscript +from pymbolic.primitives import Variable, Subscript, Lookup import logging logger = logging.getLogger(__name__) @@ -308,7 +308,9 @@ class TypeInferenceMapper(CombineMapper): # specializing an already specialized function. for id, dtype in arg_id_to_dtype.items(): - if in_knl_callable.arg_id_to_dtype[id] != arg_id_to_dtype[id]: + if id in in_knl_callable.arg_id_to_dtype and ( + in_knl_callable.arg_id_to_dtype[id] != + arg_id_to_dtype[id]): # {{{ ignoring the the cases when there is a discrepancy # between np.uint and np.int @@ -810,6 +812,9 @@ def infer_unknown_types_for_a_single_kernel(kernel, program_callables_info, def _instruction_missed_during_inference(insn): for assignee in insn.assignees: + if isinstance(assignee, Lookup): + assignee = assignee.aggregate + if isinstance(assignee, Variable): if assignee.name in kernel.arg_dict: if kernel.arg_dict[assignee.name].dtype is None: diff --git a/test/test_loopy.py b/test/test_loopy.py index 5baead8337fc38e1380e0bf1505788ade57c9209..9dc74b94f72347e3b4287e244f06292ce60527b4 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2626,7 +2626,7 @@ def test_fixed_parameters(ctx_factory): def test_parameter_inference(): knl = lp.make_kernel("{[i]: 0 <= i < n and i mod 2 = 0}", "") - assert knl.all_params() == set(["n"]) + assert knl.root_kernel.all_params() == set(["n"]) def test_execution_backend_can_cache_dtypes(ctx_factory):