From 2c1d6bcbb79488ea6d48a6b3e0d691a48560d03a Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Fri, 7 Apr 2017 16:03:35 -0500
Subject: [PATCH] Fix type inferences for reductions with inner calls.

---
 loopy/type_inference.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index cdba4a5cb..3c77c9882 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -357,25 +357,25 @@ class TypeInferenceMapper(CombineMapper):
             as a tuple type. Otherwise, the number of expressions being reduced over
             must equal 1, and the type of the first expression is returned.
         """
-        rec_results = tuple(self.rec(sub_expr) for sub_expr in expr.exprs)
+        if expr.is_plain_tuple:
+            rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs]
+        else:
+            rec_results = [self.rec(expr.exprs, return_tuple=return_tuple)]
 
         if any(len(rec_result) == 0 for rec_result in rec_results):
             return []
 
         if return_tuple:
             from itertools import product
-            return list(
-                    expr.operation.result_dtypes(self.kernel, *product_element)
-                    for product_element in product(*rec_results))
+            return [expr.operation.result_dtypes(self.kernel, *product_element)
+                    for product_element in product(*rec_results)]
 
-        else:
-            if len(rec_results) != 1:
-                raise LoopyError("reductions with more or fewer than one "
-                        "return value may only be used in direct assignments")
+        if len(rec_results) != 1:
+            raise LoopyError("reductions with more or fewer than one "
+                    "return value may only be used in direct assignments")
 
-            return list(
-                    expr.operation.result_dtypes(self.kernel, rec_result)[0]
-                    for rec_result in rec_results[0])
+        return [expr.operation.result_dtypes(self.kernel, rec_result)[0]
+                for rec_result in rec_results[0]]
 
 # }}}
 
-- 
GitLab