From 2c1d6bcbb79488ea6d48a6b3e0d691a48560d03a Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Fri, 7 Apr 2017 16:03:35 -0500 Subject: [PATCH] Fix type inferences for reductions with inner calls. --- loopy/type_inference.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/loopy/type_inference.py b/loopy/type_inference.py index cdba4a5cb..3c77c9882 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -357,25 +357,25 @@ class TypeInferenceMapper(CombineMapper): as a tuple type. Otherwise, the number of expressions being reduced over must equal 1, and the type of the first expression is returned. """ - rec_results = tuple(self.rec(sub_expr) for sub_expr in expr.exprs) + if expr.is_plain_tuple: + rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs] + else: + rec_results = [self.rec(expr.exprs, return_tuple=return_tuple)] if any(len(rec_result) == 0 for rec_result in rec_results): return [] if return_tuple: from itertools import product - return list( - expr.operation.result_dtypes(self.kernel, *product_element) - for product_element in product(*rec_results)) + return [expr.operation.result_dtypes(self.kernel, *product_element) + for product_element in product(*rec_results)] - else: - if len(rec_results) != 1: - raise LoopyError("reductions with more or fewer than one " - "return value may only be used in direct assignments") + if len(rec_results) != 1: + raise LoopyError("reductions with more or fewer than one " + "return value may only be used in direct assignments") - return list( - expr.operation.result_dtypes(self.kernel, rec_result)[0] - for rec_result in rec_results[0]) + return [expr.operation.result_dtypes(self.kernel, rec_result)[0] + for rec_result in rec_results[0]] # }}} -- GitLab