Skip to content
Snippets Groups Projects
Commit da493b5d authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Fix parallel multi-output reduction (Fixes #31 on Gitlab)

parent 3ddc668d
No related branches found
No related tags found
No related merge requests found
......@@ -417,6 +417,9 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
# }}}
neutral_var_names = [
var_name_gen("neutral_"+red_iname)
for i in range(nresults)]
acc_var_names = [
var_name_gen("acc_"+red_iname)
for i in range(nresults)]
......@@ -429,6 +432,12 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
shape=outer_local_iname_sizes + (size,),
dtype=dtype,
scope=temp_var_scope.LOCAL)
for name, dtype in zip(neutral_var_names, reduction_dtypes):
new_temporary_variables[name] = TemporaryVariable(
name=name,
shape=(),
dtype=dtype,
scope=temp_var_scope.PRIVATE)
base_iname_deps = outer_insn_inames - frozenset(expr.inames)
......@@ -446,6 +455,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
depends_on=frozenset())
generated_insns.append(init_insn)
init_neutral_id = insn_id_gen("%s_%s_init_neutral" % (insn.id, red_iname))
init_neutral_insn = make_assignment(
id=init_neutral_id,
assignees=tuple(var(nvn) for nvn in neutral_var_names),
expression=neutral,
within_inames=base_iname_deps | frozenset([base_exec_iname]),
within_inames_is_final=insn.within_inames_is_final,
depends_on=frozenset())
generated_insns.append(init_neutral_insn)
transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, red_iname))
transfer_insn = make_assignment(
id=transfer_id,
......@@ -453,12 +472,14 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
acc_var[outer_local_iname_vars + (var(red_iname),)]
for acc_var in acc_vars),
expression=expr.operation(
arg_dtype, neutral, expr.expr, expr.inames),
arg_dtype,
tuple(var(nvn) for nvn in neutral_var_names),
expr.expr, expr.inames),
within_inames=(
(outer_insn_inames - frozenset(expr.inames))
| frozenset([red_iname])),
within_inames_is_final=insn.within_inames_is_final,
depends_on=frozenset([init_id]) | insn.depends_on,
depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on,
no_sync_with=frozenset([(init_id, "any")]))
generated_insns.append(transfer_insn)
......@@ -493,15 +514,15 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
for acc_var in acc_vars),
expression=expr.operation(
arg_dtype,
_strip_if_scalar([
_strip_if_scalar(tuple(
acc_var[
outer_local_iname_vars + (var(stage_exec_iname),)]
for acc_var in acc_vars]),
_strip_if_scalar([
for acc_var in acc_vars)),
_strip_if_scalar(tuple(
acc_var[
outer_local_iname_vars + (
var(stage_exec_iname) + new_size,)]
for acc_var in acc_vars]),
for acc_var in acc_vars)),
expr.inames),
within_inames=(
base_iname_deps | frozenset([stage_exec_iname])),
......
......@@ -393,6 +393,18 @@ def test_double_sum_made_unique(ctx_factory):
assert b.get() == ref
def test_parallel_multi_output_reduction():
knl = lp.make_kernel(
"{[i]: 0<=i<128}",
"""
max_val, max_indices = argmax(i, fabs(a[i]))
""")
knl = lp.tag_inames(knl, dict(i="l.0"))
knl = lp.realize_reduction(knl)
print(knl)
# TODO: Add functional test
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment