From 7955dc8b6df43680c9c3c72771cf30b1677ff369 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 5 Apr 2019 16:11:35 -0500 Subject: [PATCH] normalizes the naming of instructions produced during parallel reductions --- loopy/preprocess.py | 11 +++++------ test/test_reduction.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 2afcd3db4..088fbb3ff 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -1177,7 +1177,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, base_iname_deps = outer_insn_inames - frozenset(expr.inames) neutral = expr.operation.neutral_element(*arg_dtypes) - init_id = insn_id_gen("%s_%s_init" % (insn.id, red_iname)) + init_id = "red_init_%s_%s" % (red_iname, insn.id) init_insn = make_assignment( id=init_id, assignees=tuple( @@ -1191,7 +1191,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, ) generated_insns.append(init_insn) - init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname)) + init_neutral_id = "red_init_neutral_%s_%s" % (red_iname, insn.id) init_neutral_insn = make_assignment( id=init_neutral_id, assignees=tuple(var(nvn) for nvn in neutral_var_names), @@ -1228,7 +1228,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, else: reduction_expr = expr.expr - transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname)) + transfer_id = "red_transfer_%s_%s" % (red_iname, insn.id) transfer_insn = make_assignment( id=transfer_id, assignees=tuple( @@ -1267,7 +1267,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, domains.append(_make_slab_set(stage_exec_iname, bound-new_size)) new_iname_tags[stage_exec_iname] = kernel.iname_tags(red_iname) - stage_id = insn_id_gen("red_%s_stage_%d" % (red_iname, istage)) + stage_id = "red_stage_%d_%s_%s" % (istage, red_iname, insn.id) stage_insn = make_assignment( id=stage_id, assignees=tuple( @@ -1299,7 +1299,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, istage += 1 new_insn_add_depends_on.add(prev_id) - new_insn_add_no_sync_with.add((prev_id, "any")) new_insn_add_within_inames.add(base_exec_iname or stage_exec_iname) if nresults == 1: @@ -1890,7 +1889,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, new_expr, = new_expressions replacement_insns = [ make_assignment( - id=insn_id_gen(insn.id), + id="red_assign_%s" % (insn.id), depends_on=result_assignment_dep_on, assignees=insn.assignees, expression=new_expr, diff --git a/test/test_reduction.py b/test/test_reduction.py index ef229d5cd..e6a708948 100644 --- a/test/test_reduction.py +++ b/test/test_reduction.py @@ -446,6 +446,45 @@ def test_reduction_with_conditional(): assert code.index("if") < code.index("for") +def test_insn_matching_from_parallel_reduce(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + "{[i]: 0<=i<4}", + """ + a = simul_reduce(sum, i, 7*i) {id=insn_A} + b = simul_reduce(sum, i, 10*i) {id=insn_B} + """) + + knl = lp.tag_inames(knl, "i:l.0") + knl = lp.realize_reduction(knl) + + # add dependencies to get 3 barriers + knl = lp.add_dependency(knl, "id:red_stage_0_i_insn_A", "id:red_init_i_insn_B " + "or id:red_init_neutral_i_insn_B or id:red_transfer_i_insn_B") + knl = lp.add_dependency(knl, "id:red_stage_0_i_insn_B", "id:red_init_i_insn_A " + "or id:red_init_neutral_i_insn_A or id:red_transfer_i_insn_A") + knl = lp.add_dependency(knl, "id:red_stage_1_i_insn_B", + "id:red_stage_0_i_insn_A") + knl = lp.add_dependency(knl, "id:red_stage_1_i_insn_A", + "id:red_stage_0_i_insn_B") + + knl = lp.tag_inames(knl, "i:l.0") + knl = lp.realize_reduction(knl) + + knl = lp.preprocess_kernel(knl) + knl = lp.get_one_scheduled_kernel(knl) + + assert sum([int(isinstance(item, lp.schedule.Barrier)) for item in + knl.schedule]) == 3 + + evt, (out_a, out_b) = knl(queue) + + assert out_a.get() == 42 + assert out_b.get() == 60 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab