diff --git a/loopy/preprocess.py b/loopy/preprocess.py index bc4c84524b3ecdd98be3b43d03be09fe3452f2dd..f6bf6ab88b1b306dea62c74dd713e270a53a0cbc 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2141,14 +2141,20 @@ class UnScopedCallCollector(CombineMapper): map_variable = map_constant map_function_symbol = map_constant + map_tagged_variable = map_constant def check_functions_are_scoped(kernel): """ Checks if all the calls in the instruction expression have been scoped, otherwise indicate to what all calls we await signature. """ + + from loopy.symbolic import SubstitutionRuleExpander + subst_expander = SubstitutionRuleExpander(kernel.substitutions) + for insn in kernel.instructions: - unscoped_calls = UnScopedCallCollector()(insn.expression) + unscoped_calls = UnScopedCallCollector()(subst_expander( + insn.expression)) if unscoped_calls: raise LoopyError("Unknown function '%s' obtained -- register a function" " or a kernel corresponding to it." % set(unscoped_calls).pop()) @@ -2278,6 +2284,7 @@ class ArgDescriptionInferer(CombineMapper): map_variable = map_constant map_function_symbol = map_constant + map_tagged_variable = map_constant def infer_arg_descr(kernel): @@ -2355,6 +2362,7 @@ class ReadyForCodegen(CombineMapper): map_variable = map_constant map_function_symbol = map_constant + map_tagged_variable = map_constant def specializing_incomplete_callables(kernel): diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 831bab5c21ccfccb4d65465eb6dd6335c7b938b5..62de58e76456f0b8e5886a37c1377eafe7aa4ac5 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -1192,12 +1192,12 @@ class FunctionToPrimitiveMapper(IdentityMapper): else: raise TypeError("cse takes two arguments") - elif name in set(["reduce, simul_reduce"]): + elif name in ["reduce", "simul_reduce"]: if len(expr.parameters) >= 3: function, inames = expr.parameters[:2] red_exprs = expr.parameters[2:] - return self._parse_reduction(str(function), inames, + return self._parse_reduction(str(function.name), inames, tuple(self.rec(red_expr) for red_expr in red_exprs), allow_simultaneous=(name == "simul_reduce")) else: