From 5201ec1f5a6c326e77d5346dbd0fc006a8cab7ae Mon Sep 17 00:00:00 2001 From: Matt Wala Date: Sun, 9 Jul 2017 15:43:35 -0500 Subject: [PATCH] Make the tuple generation work. --- loopy/preprocess.py | 8 +++++++- loopy/target/opencl.py | 5 ++++- test/test_target.py | 17 ++++++++++------- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index c331ccc82..30968630f 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -331,6 +331,9 @@ def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel): # }}} + from loopy.type_inference import TypeInferenceMapper + type_inf_mapper = TypeInferenceMapper(kernel) + from loopy.kernel.instruction import CallInstruction for insn in kernel.instructions: if not isinstance(insn, CallInstruction): @@ -352,6 +355,9 @@ def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel): FIRST_POINTER_ASSIGNEE_IDX = 1 # noqa + assignee_dtypes, = type_inf_mapper( + insn.expression, return_tuple=True, return_dtype_set=True) + for assignee_nr, assignee_var_name, assignee in zip( range(FIRST_POINTER_ASSIGNEE_IDX, len(assignees)), assignee_var_names[FIRST_POINTER_ASSIGNEE_IDX:], @@ -383,7 +389,7 @@ def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel): new_temporaries[new_assignee_name] = ( TemporaryVariable( name=new_assignee_name, - dtype=lp.auto, + dtype=assignee_dtypes[assignee_nr], scope=temp_var_scope.PRIVATE)) from pymbolic import var diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py index 01e56405e..e70acfeab 100644 --- a/loopy/target/opencl.py +++ b/loopy/target/opencl.py @@ -390,10 +390,13 @@ class OpenCLCASTBuilder(CASTBuilder): def preamble_generators(self): from loopy.library.reduction import reduction_preamble_generator + from loopy.library.tuple import tuple_preamble_generator + return ( super(OpenCLCASTBuilder, self).preamble_generators() + [ opencl_preamble_generator, - reduction_preamble_generator + reduction_preamble_generator, + tuple_preamble_generator ]) # }}} diff --git a/test/test_target.py b/test/test_target.py index 4b09829e1..2c6119552 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -176,17 +176,20 @@ def test_random123(ctx_factory, tp): assert (0 <= out).all() -def test_tuple(): +def test_tuple(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + knl = lp.make_kernel( - "{ [i]: 0 <= i < 10 }", + "{ [i]: 0 = i }", """ - a, b = make_tuple(1, 2) + a, b = make_tuple(1, 2.) """) - print( - lp.generate_code( - lp.get_one_scheduled_kernel( - lp.preprocess_kernel(knl)))[0]) + evt, (a,b) = knl(queue) + + assert a.get() == 1 + assert b.get() == 2. def test_clamp(ctx_factory): -- GitLab