diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 741f828e263050b4865f1eff85b3870c9eb2ce9b..01eeb513046be661646d440d7f3a5e7d691ae1b6 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2166,7 +2166,7 @@ class ArgDescriptionInferer(CombineMapper): def combine(self, values): import operator - return reduce(operator.or_, values, set()) + return reduce(operator.or_, values, frozenset()) def map_call(self, expr, **kwargs): from loopy.kernel.function_interface import ValueArgDescriptor @@ -2200,7 +2200,9 @@ class ArgDescriptionInferer(CombineMapper): combined_arg_id_to_dtype)) # collecting the descriptors for args, kwargs, assignees - return set(((expr, new_scoped_function),)) + return ( + frozenset(((expr, new_scoped_function), )) | + self.combine((self.rec(child) for child in expr.parameters))) def map_call_with_kwargs(self, expr, **kwargs): from loopy.kernel.function_intergace import ValueArgDescriptor @@ -2234,14 +2236,17 @@ class ArgDescriptionInferer(CombineMapper): combined_arg_id_to_descr)) # collecting the descriptors for args, kwargs, assignees - return set(((expr, new_scoped_function),)) + return ( + frozenset(((expr, new_scoped_function), )) | + self.combine((self.rec(child) for child in expr.parameters))) def map_constant(self, expr): - return set() + return frozenset() map_variable = map_constant map_function_symbol = map_constant + def infer_arg_descr(kernel): """ Specializes the kernel functions in way that the functions agree upon shape and dimensions of the arguments too. @@ -2259,8 +2264,8 @@ def infer_arg_descr(kernel): arg_description_modifier(insn.expression, assignees=insn.assignees)) if isinstance(insn, (MultiAssignmentBase, CInstruction)): - pymbolic_calls_to_functions.update(arg_description_modifier( - insn.expression)) + a = arg_description_modifier(insn.expression) + pymbolic_calls_to_functions.update(a) elif isinstance(insn, _DataObliviousInstruction): pass else: