From e3bb8f37ddffec0e9c998c2950edb330a54df478 Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Sat, 8 Jul 2017 16:45:31 -0500 Subject: [PATCH] Add a make_tuple() function to loopy. This function does trivial things, but it's there to solve the problem that the reduction neutral element getters are not allowed to store dtypes (#80). The function mangler demands that a function knows its type based on its arguments. For the neutral element getters, this is impossible because they take zero arguments. The simplest fix I can think of is to change a call to neutral_element() to a call to make_tuple(). Currently, the tuple code doesn't work yet due to pickling issues. I think the root cause is somewhere in __hackily_ensure_multi_argument_functions_are_scoped_private(). --- loopy/library/function.py | 3 ++- loopy/type_inference.py | 2 +- test/test_target.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/loopy/library/function.py b/loopy/library/function.py index efa590371..f3d14516c 100644 --- a/loopy/library/function.py +++ b/loopy/library/function.py @@ -25,8 +25,9 @@ THE SOFTWARE. def default_function_mangler(kernel, name, arg_dtypes): from loopy.library.reduction import reduction_function_mangler + from loopy.library.tuple import tuple_function_mangler - manglers = [reduction_function_mangler] + manglers = [reduction_function_mangler, tuple_function_mangler] for mangler in manglers: result = mangler(kernel, name, arg_dtypes) if result is not None: diff --git a/loopy/type_inference.py b/loopy/type_inference.py index 78d817ce7..3fb165ead 100644 --- a/loopy/type_inference.py +++ b/loopy/type_inference.py @@ -442,7 +442,7 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander): result_i = comp_dtype_set break - assert found + assert found, var_name if result_i is not None: result.append(result_i) diff --git a/test/test_target.py b/test/test_target.py index b656383e7..4b09829e1 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -176,6 +176,19 @@ def test_random123(ctx_factory, tp): assert (0 <= out).all() +def test_tuple(): + knl = lp.make_kernel( + "{ [i]: 0 <= i < 10 }", + """ + a, b = make_tuple(1, 2) + """) + + print( + lp.generate_code( + lp.get_one_scheduled_kernel( + lp.preprocess_kernel(knl)))[0]) + + def test_clamp(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) -- GitLab