From dd154ce56f9acc93976a5b4fae656ecf674dbd26 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Fri, 7 Apr 2017 19:38:33 -0500 Subject: [PATCH] Fix type inference. --- loopy/type_inference.py | 30 +++++++++++++++--------------- test/test_loopy.py | 7 ++++--- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/loopy/type_inference.py b/loopy/type_inference.py index b6aa5d1ad..bd5a230dc 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -362,28 +362,28 @@ class TypeInferenceMapper(CombineMapper): if isinstance(expr.exprs, tuple): rec_results = [self.rec(sub_expr) for sub_expr in expr.exprs] + if return_tuple: + from itertools import product + rec_results = product(*rec_results) + else: + rec_results = rec_results[0] elif isinstance(expr.exprs, Reduction): - rec_results = [self.rec(expr.exprs, return_tuple=True)] + rec_results = self.rec(expr.exprs, return_tuple=return_tuple) elif isinstance(expr.exprs, Call): - rec_results = [self.map_call(expr.exprs, return_tuple=return_tuple)] + rec_results = self.map_call(expr.exprs, return_tuple=return_tuple) else: raise LoopyError("unknown reduction type: '%s'" % type(expr.exprs).__name__) - if any(len(rec_result) == 0 for rec_result in rec_results): - return [] - - if return_tuple: - from itertools import product - return [expr.operation.result_dtypes(self.kernel, *product_element) - for product_element in product(*rec_results)] - - if len(rec_results) != 1: - raise LoopyError("reductions with more or fewer than one " - "return value may only be used in direct assignments") + if not return_tuple: + if any(isinstance(rec_result, tuple) for rec_result in rec_results): + raise LoopyError("reductions with more or fewer than one " + "return value may only be used in direct assignments") + return [expr.operation.result_dtypes(self.kernel, rec_result)[0] + for rec_result in rec_results] - return [expr.operation.result_dtypes(self.kernel, rec_result)[0] - for rec_result in rec_results[0]] + return [expr.operation.result_dtypes(self.kernel, *rec_result) + for rec_result in rec_results] # }}} diff --git a/test/test_loopy.py b/test/test_loopy.py index b92161ac7..1cd025c99 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1026,7 +1026,7 @@ def test_within_inames_and_reduction(): from pymbolic.primitives import Subscript, Variable i2 = lp.Assignment("a", - lp.Reduction("sum", "j", Subscript(Variable("phi"), Variable("j"))), + lp.Reduction("sum", "j", (Subscript(Variable("phi"), Variable("j")),)), within_inames=frozenset(), within_inames_is_final=True) @@ -2135,8 +2135,9 @@ def test_multi_argument_reduction_type_inference(): t_inf_mapper = TypeInferenceMapper(knl) - print(t_inf_mapper(expr, return_tuple=True)) - 1/0 + assert ( + t_inf_mapper(expr, return_tuple=True, return_dtype_set=True) + == [(int32, int32)]) if __name__ == "__main__": -- GitLab