diff --git a/loopy/type_inference.py b/loopy/type_inference.py index cdba4a5cb43bf04c94452f34684e2729770f8d09..3c77c988261b63334f3cb8f0f84e2ea69c87901b 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -357,25 +357,25 @@ class TypeInferenceMapper(CombineMapper): as a tuple type. Otherwise, the number of expressions being reduced over must equal 1, and the type of the first expression is returned. """ - rec_results = tuple(self.rec(sub_expr) for sub_expr in expr.exprs) + if expr.is_plain_tuple: + rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs] + else: + rec_results = [self.rec(expr.exprs, return_tuple=return_tuple)] if any(len(rec_result) == 0 for rec_result in rec_results): return [] if return_tuple: from itertools import product - return list( - expr.operation.result_dtypes(self.kernel, *product_element) - for product_element in product(*rec_results)) + return [expr.operation.result_dtypes(self.kernel, *product_element) + for product_element in product(*rec_results)] - else: - if len(rec_results) != 1: - raise LoopyError("reductions with more or fewer than one " - "return value may only be used in direct assignments") + if len(rec_results) != 1: + raise LoopyError("reductions with more or fewer than one " + "return value may only be used in direct assignments") - return list( - expr.operation.result_dtypes(self.kernel, rec_result)[0] - for rec_result in rec_results[0]) + return [expr.operation.result_dtypes(self.kernel, rec_result)[0] + for rec_result in rec_results[0]] # }}}