From 19cc672990effff5a7e119a6582b2943e3dda6f7 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 18 Mar 2018 19:44:28 -0500 Subject: [PATCH] Removed the logic error in ArgDescriptorInferer --- loopy/preprocess.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 741f828e2..01eeb5130 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -2166,7 +2166,7 @@ class ArgDescriptionInferer(CombineMapper): def combine(self, values): import operator - return reduce(operator.or_, values, set()) + return reduce(operator.or_, values, frozenset()) def map_call(self, expr, **kwargs): from loopy.kernel.function_interface import ValueArgDescriptor @@ -2200,7 +2200,9 @@ class ArgDescriptionInferer(CombineMapper): combined_arg_id_to_dtype)) # collecting the descriptors for args, kwargs, assignees - return set(((expr, new_scoped_function),)) + return ( + frozenset(((expr, new_scoped_function), )) | + self.combine((self.rec(child) for child in expr.parameters))) def map_call_with_kwargs(self, expr, **kwargs): from loopy.kernel.function_intergace import ValueArgDescriptor @@ -2234,14 +2236,17 @@ class ArgDescriptionInferer(CombineMapper): combined_arg_id_to_descr)) # collecting the descriptors for args, kwargs, assignees - return set(((expr, new_scoped_function),)) + return ( + frozenset(((expr, new_scoped_function), )) | + self.combine((self.rec(child) for child in expr.parameters))) def map_constant(self, expr): - return set() + return frozenset() map_variable = map_constant map_function_symbol = map_constant + def infer_arg_descr(kernel): """ Specializes the kernel functions in way that the functions agree upon shape and dimensions of the arguments too. @@ -2259,8 +2264,8 @@ def infer_arg_descr(kernel): arg_description_modifier(insn.expression, assignees=insn.assignees)) if isinstance(insn, (MultiAssignmentBase, CInstruction)): - pymbolic_calls_to_functions.update(arg_description_modifier( - insn.expression)) + a = arg_description_modifier(insn.expression) + pymbolic_calls_to_functions.update(a) elif isinstance(insn, _DataObliviousInstruction): pass else: -- GitLab