From dd154ce56f9acc93976a5b4fae656ecf674dbd26 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Fri, 7 Apr 2017 19:38:33 -0500
Subject: [PATCH] Fix type inference.

---
 loopy/type_inference.py | 30 +++++++++++++++---------------
 test/test_loopy.py      |  7 ++++---
 2 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index b6aa5d1ad..bd5a230dc 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -362,28 +362,28 @@ class TypeInferenceMapper(CombineMapper):
 
         if isinstance(expr.exprs, tuple):
             rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs]
+            if return_tuple:
+                from itertools import product
+                rec_results = product(*rec_results)
+            else:
+                rec_results = rec_results[0]
         elif isinstance(expr.exprs, Reduction):
-            rec_results = [self.rec(expr.exprs, return_tuple=True)]
+            rec_results = self.rec(expr.exprs, return_tuple=return_tuple)
         elif isinstance(expr.exprs, Call):
-            rec_results = [self.map_call(expr.exprs, return_tuple=return_tuple)]
+            rec_results = self.map_call(expr.exprs, return_tuple=return_tuple)
         else:
             raise LoopyError("unknown reduction type: '%s'"
                              % type(expr.exprs).__name__)
 
-        if any(len(rec_result) == 0 for rec_result in rec_results):
-            return []
-
-        if return_tuple:
-            from itertools import product
-            return [expr.operation.result_dtypes(self.kernel, *product_element)
-                    for product_element in product(*rec_results)]
-
-        if len(rec_results) != 1:
-            raise LoopyError("reductions with more or fewer than one "
-                    "return value may only be used in direct assignments")
+        if not return_tuple:
+            if any(isinstance(rec_result, tuple) for rec_result in rec_results):
+                raise LoopyError("reductions with more or fewer than one "
+                                 "return value may only be used in direct assignments")
+            return [expr.operation.result_dtypes(self.kernel, rec_result)[0]
+                for rec_result in rec_results]
 
-        return [expr.operation.result_dtypes(self.kernel, rec_result)[0]
-                for rec_result in rec_results[0]]
+        return [expr.operation.result_dtypes(self.kernel, *rec_result)
+            for rec_result in rec_results]
 
 # }}}
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index b92161ac7..1cd025c99 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1026,7 +1026,7 @@ def test_within_inames_and_reduction():
 
     from pymbolic.primitives import Subscript, Variable
     i2 = lp.Assignment("a",
-            lp.Reduction("sum", "j", Subscript(Variable("phi"), Variable("j"))),
+            lp.Reduction("sum", "j", (Subscript(Variable("phi"), Variable("j")),)),
             within_inames=frozenset(),
             within_inames_is_final=True)
 
@@ -2135,8 +2135,9 @@ def test_multi_argument_reduction_type_inference():
 
     t_inf_mapper = TypeInferenceMapper(knl)
 
-    print(t_inf_mapper(expr, return_tuple=True))
-    1/0
+    assert (
+            t_inf_mapper(expr, return_tuple=True, return_dtype_set=True)
+            == [(int32, int32)])
 
 
 if __name__ == "__main__":
-- 
GitLab