# Patch summary (free-text preamble placed ahead of the first "diff --git" header).
#
# Files touched: loopy/preprocess.py (hunks inside realize_reduction) and
# test/test_reduction.py.
#
# loopy/preprocess.py:
#   * Generates one "neutral_<red_iname>" variable name per reduction result
#     (range(nresults)), next to the existing "acc_<red_iname>" names.
#   * Registers a TemporaryVariable for each of those names with shape=(),
#     the dtype taken from reduction_dtypes, and scope=temp_var_scope.PRIVATE.
#   * Adds an "<insn.id>_<red_iname>_init_neutral" assignment that writes
#     `neutral` into those variables; it runs within
#     base_iname_deps | {base_exec_iname} and has an empty depends_on.
#   * The "transfer" assignment now hands the neutral *variables* (passed
#     through _strip_if_scalar) to expr.operation instead of the `neutral`
#     expression itself, and additionally depends on init_neutral_id.
#   * _strip_if_scalar is moved up so it is defined before its new first use.
#   * The two per-stage operand collections passed to _strip_if_scalar are
#     built as tuples instead of lists.
#
# test/test_reduction.py:
#   * Adds test_parallel_multi_output_reduction: builds a kernel with an
#     argmax over i in [0, 128), tags i as "l.0", calls lp.realize_reduction
#     and prints the kernel.
#
# NOTE(review): _strip_if_scalar branches on len(acc_vars), not on len(c).
#   Both call sites shown pass a collection with one entry per result, so this
#   presumably agrees with len(c) -- confirm acc_vars always has nresults
#   entries, or switch the check to len(c).
# NOTE(review): the loop index `i` in the new neutral_var_names comprehension
#   is unused (same as in the existing acc_var_names comprehension).
# NOTE(review): the new test makes no assertions (see its "TODO: Add
#   functional test"); it only checks that realize_reduction does not raise.
# NOTE(review): the patch text below has its line breaks collapsed; each
#   hunk line presumably has to be restored onto its own line before the
#   patch can be applied -- verify against the original commit.
diff --git a/loopy/preprocess.py b/loopy/preprocess.py index db7792cce55c9b7851850c3059e821fc574c3270..2b6d97c38a12b47e5b4653297c18b24c40ed938b 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -417,6 +417,9 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): # }}} + neutral_var_names = [ + var_name_gen("neutral_"+red_iname) + for i in range(nresults)] acc_var_names = [ var_name_gen("acc_"+red_iname) for i in range(nresults)] @@ -429,6 +432,12 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): shape=outer_local_iname_sizes + (size,), dtype=dtype, scope=temp_var_scope.LOCAL) + for name, dtype in zip(neutral_var_names, reduction_dtypes): + new_temporary_variables[name] = TemporaryVariable( + name=name, + shape=(), + dtype=dtype, + scope=temp_var_scope.PRIVATE) base_iname_deps = outer_insn_inames - frozenset(expr.inames) @@ -446,6 +455,22 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): depends_on=frozenset()) generated_insns.append(init_insn) + def _strip_if_scalar(c): + if len(acc_vars) == 1: + return c[0] + else: + return c + + init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname)) + init_neutral_insn = make_assignment( + id=init_neutral_id, + assignees=tuple(var(nvn) for nvn in neutral_var_names), + expression=neutral, + within_inames=base_iname_deps | frozenset([base_exec_iname]), + within_inames_is_final=insn.within_inames_is_final, + depends_on=frozenset()) + generated_insns.append(init_neutral_insn) + transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname)) transfer_insn = make_assignment( id=transfer_id, @@ -453,21 +478,17 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): acc_var[outer_local_iname_vars + (var(red_iname),)] for acc_var in acc_vars), expression=expr.operation( - arg_dtype, neutral, expr.expr, expr.inames), + arg_dtype, + _strip_if_scalar(tuple(var(nvn) for nvn in neutral_var_names)), + 
expr.expr, expr.inames), within_inames=( (outer_insn_inames - frozenset(expr.inames)) | frozenset([red_iname])), within_inames_is_final=insn.within_inames_is_final, - depends_on=frozenset([init_id]) | insn.depends_on, + depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on, no_sync_with=frozenset([(init_id, "any")])) generated_insns.append(transfer_insn) - def _strip_if_scalar(c): - if len(acc_vars) == 1: - return c[0] - else: - return c - cur_size = 1 while cur_size < size: cur_size *= 2 @@ -493,15 +514,15 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True): for acc_var in acc_vars), expression=expr.operation( arg_dtype, - _strip_if_scalar([ + _strip_if_scalar(tuple( acc_var[ outer_local_iname_vars + (var(stage_exec_iname),)] - for acc_var in acc_vars]), - _strip_if_scalar([ + for acc_var in acc_vars)), + _strip_if_scalar(tuple( acc_var[ outer_local_iname_vars + ( var(stage_exec_iname) + new_size,)] - for acc_var in acc_vars]), + for acc_var in acc_vars)), expr.inames), within_inames=( base_iname_deps | frozenset([stage_exec_iname])), diff --git a/test/test_reduction.py b/test/test_reduction.py index 820c669da494f4d8863d274120cd5c0c7eb4420f..5887df7a628c46fbf09539fdd48c08aaacd8e409 100644 --- a/test/test_reduction.py +++ b/test/test_reduction.py @@ -393,6 +393,18 @@ def test_double_sum_made_unique(ctx_factory): assert b.get() == ref +def test_parallel_multi_output_reduction(): + knl = lp.make_kernel( + "{[i]: 0<=i<128}", + """ + max_val, max_indices = argmax(i, fabs(a[i])) + """) + knl = lp.tag_inames(knl, dict(i="l.0")) + knl = lp.realize_reduction(knl) + print(knl) + # TODO: Add functional test + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])