From e3bb8f37ddffec0e9c998c2950edb330a54df478 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Sat, 8 Jul 2017 16:45:31 -0500
Subject: [PATCH] Add a make_tuple() function to loopy.

This function does trivial things, but it's there to solve the problem
that the reduction neutral element getters are not allowed to store
dtypes (#80).

The function mangler demands that a function knows its type based on
its arguments. For the neutral element getters, this is impossible
because they take zero arguments. The simplest fix I can think of is
to change a call to neutral_element() to a call to make_tuple().

Currently, the tuple code doesn't work yet due to pickling issues.  I
think the root cause is somewhere in
__hackily_ensure_multi_argument_functions_are_scoped_private().
---
 loopy/library/function.py |  3 ++-
 loopy/type_inference.py   |  2 +-
 test/test_target.py       | 13 +++++++++++++
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/loopy/library/function.py b/loopy/library/function.py
index efa590371..f3d14516c 100644
--- a/loopy/library/function.py
+++ b/loopy/library/function.py
@@ -25,8 +25,9 @@ THE SOFTWARE.
 
 def default_function_mangler(kernel, name, arg_dtypes):
     from loopy.library.reduction import reduction_function_mangler
+    from loopy.library.tuple import tuple_function_mangler
 
-    manglers = [reduction_function_mangler]
+    manglers = [reduction_function_mangler, tuple_function_mangler]
     for mangler in manglers:
         result = mangler(kernel, name, arg_dtypes)
         if result is not None:
diff --git a/loopy/type_inference.py b/loopy/type_inference.py
index 78d817ce7..3fb165ead 100644
--- a/loopy/type_inference.py
+++ b/loopy/type_inference.py
@@ -442,7 +442,7 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
                         result_i = comp_dtype_set
                         break
 
-                assert found
+                assert found, var_name
                 if result_i is not None:
                     result.append(result_i)
 
diff --git a/test/test_target.py b/test/test_target.py
index b656383e7..4b09829e1 100644
--- a/test/test_target.py
+++ b/test/test_target.py
@@ -176,6 +176,19 @@ def test_random123(ctx_factory, tp):
     assert (0 <= out).all()
 
 
+def test_tuple():
+    knl = lp.make_kernel(
+            "{ [i]: 0 <= i < 10 }",
+            """
+            a, b = make_tuple(1, 2)
+            """)
+
+    print(
+            lp.generate_code(
+                lp.get_one_scheduled_kernel(
+                    lp.preprocess_kernel(knl)))[0])
+
+
 def test_clamp(ctx_factory):
     ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
-- 
GitLab