From 9a13f08e00bf871591b9482951ac9c9f742c4c7b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 11 Jul 2013 15:27:51 -0400
Subject: [PATCH] Fix reduction splitting interface

---
 doc/reference.rst  |  4 +++-
 loopy/__init__.py  | 13 +++++++++++--
 test/test_loopy.py |  2 +-
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index 66eb78593..529e86677 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -201,7 +201,9 @@ Wrangling inames
 
 .. autofunction:: set_loop_priority
 
-.. autofunction:: split_reduction
+.. autofunction:: split_reduction_inward
+
+.. autofunction:: split_reduction_outward
 
 Dealing with Substitution Rules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 3b22e4c97..37fe6f320 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -1122,8 +1122,7 @@ class _ReductionSplitter(ExpandingIdentityMapper):
             return ExpandingIdentityMapper.map_reduction(self, expr, expn_state)
 
 
-def split_reduction(kernel, inames, direction, within=None):
-    # FIXME document me
+def _split_reduction(kernel, inames, direction, within=None):
     if direction not in ["in", "out"]:
         raise ValueError("invalid value for 'direction': %s" % direction)
 
@@ -1137,6 +1136,16 @@ def split_reduction(kernel, inames, direction, within=None):
     rsplit = _ReductionSplitter(kernel, within, inames, direction)
     return rsplit.map_kernel(kernel)
 
+
+def split_reduction_inward(kernel, inames, within=None):
+    # FIXME document me
+    _split_reduction(kernel, inames, "in", within)
+
+
+def split_reduction_outward(kernel, inames, within=None):
+    # FIXME document me
+    _split_reduction(kernel, inames, "out", within)
+
 # }}}
 
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index cb7a9e5b3..455917481 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1264,7 +1264,7 @@ def test_split_reduction(ctx_factory):
                     None, shape=None),
                 "..."])
 
-    knl = lp.split_reduction(knl, "j,k", "out")
+    knl = lp.split_reduction_outward(knl, "j,k")
     print knl
     # FIXME: finish test
 
-- 
GitLab