From 83428f328e9ef433f9422809562d82e6c52d8819 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Wed, 19 Jul 2017 01:04:27 -0500
Subject: [PATCH] Be less strict about data types in tuples / reductions.

---
 loopy/library/reduction.py |  6 ++++--
 loopy/target/c/__init__.py | 15 ++++-----------
 test/test_scan.py          |  1 -
 3 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/loopy/library/reduction.py b/loopy/library/reduction.py
index bd085b7e8..3c5f4a142 100644
--- a/loopy/library/reduction.py
+++ b/loopy/library/reduction.py
@@ -123,7 +123,8 @@ class ScalarReductionOperation(ReductionOperation):
 
 class SumReductionOperation(ScalarReductionOperation):
     def neutral_element(self, dtype):
-        return dtype.numpy_dtype.type(0)
+        # FIXME: Document that we always use an int here.
+        return 0
 
     def __call__(self, dtype, operand1, operand2):
         return operand1 + operand2
@@ -131,7 +132,8 @@ class SumReductionOperation(ScalarReductionOperation):
 
 class ProductReductionOperation(ScalarReductionOperation):
     def neutral_element(self, dtype):
-        return dtype.numpy_dtype.type(1)
+        # FIXME: Document that we always use an int here.
+        return 1
 
     def __call__(self, dtype, operand1, operand2):
         return operand1 * operand2
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index ed1ba1ce9..e9457233f 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -651,18 +651,11 @@ class CASTBuilder(ASTBuilderBase):
     def emit_tuple_assignment(self, codegen_state, insn):
         ecm = codegen_state.expression_to_code_mapper
 
-        parameters = insn.expression.parameters
-        parameter_dtypes = tuple(ecm.infer_type(par) for par in parameters)
-
         from cgen import Assign, block_if_necessary
         assignments = []
 
-        for i, (assignee, tgt_dtype) in enumerate(
-                zip(insn.assignees, parameter_dtypes)):
-            if tgt_dtype != ecm.infer_type(assignee):
-                raise LoopyError("type mismatch in %d'th (0-based) left-hand "
-                        "side of instruction '%s'" % (i, insn.id))
-
+        for i, (assignee, parameter) in enumerate(
+                zip(insn.assignees, insn.expression.parameters)):
             lhs_code = ecm(assignee, prec=PREC_NONE, type_context=None)
             assignee_var_name = insn.assignee_var_names()[i]
             lhs_var = codegen_state.kernel.get_var_descriptor(assignee_var_name)
@@ -671,8 +664,8 @@ class CASTBuilder(ASTBuilderBase):
             from loopy.expression import dtype_to_type_context
             rhs_type_context = dtype_to_type_context(
                     codegen_state.kernel.target, lhs_dtype)
-            rhs_code = ecm(parameters[i], prec=PREC_NONE,
-                           type_context=rhs_type_context, needed_dtype=lhs_dtype)
+            rhs_code = ecm(parameter, prec=PREC_NONE,
+                    type_context=rhs_type_context, needed_dtype=lhs_dtype)
 
             assignments.append(Assign(lhs_code, rhs_code))
 
diff --git a/test/test_scan.py b/test/test_scan.py
index c225c2c1c..08754819c 100644
--- a/test/test_scan.py
+++ b/test/test_scan.py
@@ -182,7 +182,6 @@ def test_nested_scan(ctx_factory, i_tag, j_tag):
     knl = lp.fix_parameters(knl, n=10)
     knl = lp.tag_inames(knl, dict(i=i_tag, j=j_tag))
 
-    knl = lp.add_dtypes(knl, dict(tmp=int))
     knl = lp.realize_reduction(knl, force_scan=True)
 
     print(knl)
-- 
GitLab