diff --git a/doc/reference.rst b/doc/reference.rst
index b415e40afa08ad57188f069d3d2b5a8197a38748..66eb78593e9a5fba6a82beae5bfbc24729301cff 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -201,6 +201,8 @@ Wrangling inames
 
 .. autofunction:: set_loop_priority
 
+.. autofunction:: split_reduction
+
 Dealing with Substitution Rules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index b15f447d0e311043d0a2f0109d572bcdbedb4ffe..3b22e4c978da233aafe64909969cdd8447472b03 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1091,4 +1091,93 @@ def tag_data_axes(knl, ary_names, dim_tags):
 
 # }}}
 
+
+# {{{ split_reduction
+
+class _ReductionSplitter(ExpandingIdentityMapper):
+    """Mapper that rewrites matching reductions as two nested reductions.
+
+    A reduction is rewritten if all of *inames* occur among its inames and
+    *within* accepts the current expansion stack. All other reductions are
+    handled by :class:`ExpandingIdentityMapper`.
+    """
+
+    def __init__(self, kernel, within, inames, direction):
+        ExpandingIdentityMapper.__init__(self,
+                kernel.substitutions, kernel.get_var_name_generator())
+
+        self.within = within
+        self.inames = inames
+        self.direction = direction
+
+    def map_reduction(self, expr, expn_state):
+        if not (self.inames <= set(expr.inames)
+                and self.within(expn_state.stack)):
+            return ExpandingIdentityMapper.map_reduction(
+                    self, expr, expn_state)
+
+        if self.direction not in ("in", "out"):
+            # Raise rather than assert: asserts vanish under 'python -O',
+            # which would make this method silently return None.
+            raise ValueError(
+                    "invalid value for 'direction': %s" % self.direction)
+
+        # Partition in the order given by the original reduction (rather
+        # than in set iteration order) so that the result is deterministic.
+        split_inames = tuple(
+                iname for iname in expr.inames if iname in self.inames)
+        leftover_inames = tuple(
+                iname for iname in expr.inames if iname not in self.inames)
+        # NOTE(review): leftover_inames is empty if *inames* covers the
+        # entire reduction--confirm that an iname-less reduction is OK.
+
+        if self.direction == "in":
+            outer_inames, inner_inames = leftover_inames, split_inames
+        else:
+            outer_inames, inner_inames = split_inames, leftover_inames
+
+        from loopy.symbolic import Reduction
+        return Reduction(expr.operation, outer_inames,
+                Reduction(expr.operation, inner_inames,
+                    self.rec(expr.expr, expn_state)))
+
+
+def split_reduction(kernel, inames, direction, within=None):
+    """Split each matching reduction in *kernel* into two nested reductions,
+    one over *inames* and one over the remaining reduction inames.
+
+    Only reductions whose inames include all of *inames* are changed.
+
+    :arg inames: the inames to split off, as a comma-separated string or
+        an iterable of iname names.
+    :arg direction: ``"in"`` to make the reduction over *inames* the inner
+        one, ``"out"`` to make it the outer one.
+    :arg within: a stack match, passed to
+        :func:`loopy.context_matching.parse_stack_match`, restricting
+        where reductions are split.
+    :returns: the result of mapping *kernel* through the splitter.
+    :raises ValueError: if *direction* is invalid or *inames* is empty.
+    """
+
+    if direction not in ["in", "out"]:
+        raise ValueError("invalid value for 'direction': %s" % direction)
+
+    if isinstance(inames, str):
+        # Tolerate whitespace and stray commas, as in "j, k".
+        inames = [iname.strip() for iname in inames.split(",")]
+    inames = set(iname for iname in inames if iname)
+
+    if not inames:
+        # An empty set is a subset of everything, so this would otherwise
+        # wrap every reduction in a pointless iname-less one.
+        raise ValueError("no inames specified for split_reduction")
+
+    from loopy.context_matching import parse_stack_match
+    within = parse_stack_match(within)
+
+    rsplit = _ReductionSplitter(kernel, within, inames, direction)
+    return rsplit.map_kernel(kernel)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index f58c93787b9192087ccc8613bba4e50b53966b51..cb7a9e5b3839630880dfa587600d0569c4f48128 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1250,6 +1250,24 @@ def test_inames_deps_from_write_subscript(ctx_factory):
     assert "i" in knl.insn_inames("myred")
 
 
+def test_split_reduction(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0], [
+        "{[i,j,k]: 0<=i,j,k<n}",
+        ],
+        """
+        b = sum((i,j,k), a[i,j,k])
+        """,
+        [
+            lp.GlobalArg("a", None, shape=None),
+            "..."])
+
+    knl = lp.split_reduction(knl, "j,k", "out")
+    print(knl)
+    # FIXME: finish test
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])