From 2dcdc33c82a27d22c35fb55611b18344c50543d7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 11 Jul 2013 14:36:55 -0400
Subject: [PATCH] Add split_reduction

---
 doc/reference.rst  |  2 ++
 loopy/__init__.py  | 48 ++++++++++++++++++++++++++++++++++++++++++++++
 test/test_loopy.py | 19 ++++++++++++++++++
 3 files changed, 69 insertions(+)

diff --git a/doc/reference.rst b/doc/reference.rst
index b415e40af..66eb78593 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -201,6 +201,8 @@ Wrangling inames
 
 .. autofunction:: set_loop_priority
 
+.. autofunction:: split_reduction
+
 Dealing with Substitution Rules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index b15f447d0..3b22e4c97 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1091,4 +1091,52 @@ def tag_data_axes(knl, ary_names, dim_tags):
 
 # }}}
 
+
+# {{{ split_reduction
+
+class _ReductionSplitter(ExpandingIdentityMapper):
+    def __init__(self, kernel, within, inames, direction):
+        ExpandingIdentityMapper.__init__(self,
+                kernel.substitutions, kernel.get_var_name_generator())
+
+        self.within = within
+        self.inames = inames
+        self.direction = direction
+
+    def map_reduction(self, expr, expn_state):
+        if self.inames <= set(expr.inames) and self.within(expn_state.stack):
+            leftover_inames = set(expr.inames) - self.inames
+
+            from loopy.symbolic import Reduction
+            if self.direction == "in":
+                return Reduction(expr.operation, tuple(leftover_inames),
+                        Reduction(expr.operation, tuple(self.inames),
+                            self.rec(expr.expr, expn_state)))
+            elif self.direction == "out":
+                return Reduction(expr.operation, tuple(self.inames),
+                        Reduction(expr.operation, tuple(leftover_inames),
+                            self.rec(expr.expr, expn_state)))
+            else:
+                assert False
+        else:
+            return ExpandingIdentityMapper.map_reduction(self, expr, expn_state)
+
+
+def split_reduction(kernel, inames, direction, within=None):
+    # FIXME document me
+    if direction not in ["in", "out"]:
+        raise ValueError("invalid value for 'direction': %s" % direction)
+
+    if isinstance(inames, str):
+        inames = inames.split(",")
+    inames = set(inames)
+
+    from loopy.context_matching import parse_stack_match
+    within = parse_stack_match(within)
+
+    rsplit = _ReductionSplitter(kernel, within, inames, direction)
+    return rsplit.map_kernel(kernel)
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index f58c93787..cb7a9e5b3 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1250,6 +1250,25 @@ def test_inames_deps_from_write_subscript(ctx_factory):
     assert "i" in knl.insn_inames("myred")
 
 
+def test_split_reduction(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0], [
+            "{[i,j,k]: 0<=i,j,k<n}",
+            ],
+            """
+                b = sum((i,j,k), a[i,j,k])
+                """,
+            [
+                lp.GlobalArg("box_source_starts,box_source_counts_nonchild,a",
+                    None, shape=None),
+                "..."])
+
+    knl = lp.split_reduction(knl, "j,k", "out")
+    print knl
+    # FIXME: finish test
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab