From 02537592fd8d9961c223ec3da64ec1c2aade249b Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Sun, 23 Apr 2017 21:04:10 -0500
Subject: [PATCH] Basic scans working.
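
Adds a ReductionIsNotTriangularError diagnostic, factors the scan argument
handling in realize_reduction into a preprocess_scan_arguments helper (with
dependency tracking for the instructions it generates), and adds an initial
scan test suite in test/test_scan.py.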

---
 loopy/diagnostic.py |   4 +
 loopy/preprocess.py | 109 +++++++----
 test/test_scan.py   | 432 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 509 insertions(+), 36 deletions(-)
 create mode 100644 test/test_scan.py

diff --git a/loopy/diagnostic.py b/loopy/diagnostic.py
index 15ab8a1ee..512e4ac86 100644
--- a/loopy/diagnostic.py
+++ b/loopy/diagnostic.py
@@ -103,6 +103,10 @@ class MissingDefinitionError(LoopyError):
 class UnscheduledInstructionError(LoopyError):
     pass
 
+
+class ReductionIsNotTriangularError(LoopyError):
+    pass
+
 # }}}
 
 
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 808273f02..bbc15c03a 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -694,13 +694,6 @@ def _create_domain_for_sweep_tracking(orig_domain,
     return subd
 
 
-def _strip_if_scalar(reference_exprs, expr):
-    if len(reference_exprs) == 1:
-        return expr[0]
-    else:
-        return expr
-
-
 def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel):
     """
     Multi assignment function calls are currently lowered into OpenCL so that
@@ -928,6 +921,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
     var_name_gen = kernel.get_var_name_generator()
     new_temporary_variables = kernel.temporary_variables.copy()
+    inames_added_for_scan = set()
+    inames_to_remove = set()
 
     # {{{ helpers
 
@@ -937,8 +932,44 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
         else:
             return val
 
+    def preprocess_scan_arguments(
+                insn, expr, nresults, scan_iname, track_iname,
+                newly_generated_insn_id_set):
+        """Does iname substitution within scan arguments and returns a set of values
+        suitable to be passed to the binary op. Returns a tuple."""
+
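+        # For example, with a scan iname "j" tracked by an iname "j__scan"
+        # (names illustrative only), an argument such as a[j] is rewritten to
+        # a[j__scan] before it is handed to the binary op.
+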
+        if nresults > 1:
+            inner_expr = expr
+
+            # In the case of a multi-argument scan, we need a name for each of
+            # the arguments in order to pass them to the binary op - so we expand
+            # items that are not "plain" tuples here.
+            if not isinstance(inner_expr, tuple):
+                get_args_insn_id = insn_id_gen(
+                        "%s_%s_get" % (insn.id, "_".join(expr.inames)))
+
+                inner_expr = expand_inner_reduction(
+                        id=get_args_insn_id,
+                        expr=inner_expr,
+                        nresults=nresults,
+                        depends_on=insn.depends_on,
+                        within_inames=insn.within_inames | expr.inames,
+                        within_inames_is_final=insn.within_inames_is_final)
+
+                newly_generated_insn_id_set.add(get_args_insn_id)
+
+            updated_inner_exprs = tuple(
+                    replace_var_within_expr(sub_expr, scan_iname, track_iname)
+                    for sub_expr in inner_expr)
+        else:
+            updated_inner_exprs = (
+                    replace_var_within_expr(expr, scan_iname, track_iname),)
+
+        return updated_inner_exprs
+
     def expand_inner_reduction(id, expr, nresults, depends_on, within_inames,
             within_inames_is_final):
+        # FIXME: use make_temporaries
         from pymbolic.primitives import Call
         from loopy.symbolic import Reduction
         assert isinstance(expr, (Call, Reduction))
@@ -988,7 +1019,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
         init_insn_depends_on = frozenset()
 
-        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+        global_barrier = lp.find_most_recent_global_barrier(temp_kernel, insn.id)
 
         if global_barrier is not None:
             init_insn_depends_on |= frozenset([global_barrier])
@@ -1018,17 +1049,20 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
         reduction_insn_depends_on = set([init_id])
 
+        # In the case of a multi-argument reduction, we need a name for each of
+        # the arguments in order to pass them to the binary op - so we expand
+        # items that are not "plain" tuples here.
         if nresults > 1 and not isinstance(expr.expr, tuple):
             get_args_insn_id = insn_id_gen(
                     "%s_%s_get" % (insn.id, "_".join(expr.inames)))
 
             reduction_expr = expand_inner_reduction(
-                id=get_args_insn_id,
-                expr=expr.expr,
-                nresults=nresults,
-                depends_on=insn.depends_on,
-                within_inames=update_insn_iname_deps,
-                within_inames_is_final=insn.within_inames_is_final)
+                    id=get_args_insn_id,
+                    expr=expr.expr,
+                    nresults=nresults,
+                    depends_on=insn.depends_on,
+                    within_inames=update_insn_iname_deps,
+                    within_inames_is_final=insn.within_inames_is_final)
 
             reduction_insn_depends_on.add(get_args_insn_id)
         else:
@@ -1165,6 +1199,9 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
         transfer_depends_on = set([init_neutral_id, init_id])
 
+        # In the case of a multi-argument reduction, we need a name for each of
+        # the arguments in order to pass them to the binary op - so we expand
+        # items that are not "plain" tuples here.
         if nresults > 1 and not isinstance(expr.expr, tuple):
             get_args_insn_id = insn_id_gen(
                     "%s_%s_get" % (insn.id, red_iname))
@@ -1192,9 +1229,9 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                 expression=expr.operation(
                     arg_dtypes,
                     _strip_if_scalar(
-                        expr.exprs,
+                        neutral_var_names,
                         tuple(var(nvn) for nvn in neutral_var_names)),
-                    _strip_if_scalar(expr.exprs, expr.exprs)),
+                    reduction_expr),
                 within_inames=(
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
@@ -1263,6 +1300,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
     # {{{ utils (stateful)
 
+    from pytools import memoize
+
     @memoize
     def get_or_add_sweep_tracking_iname_and_domain(
             scan_iname, sweep_iname, sweep_min_value, scan_min_value, stride,
@@ -1345,7 +1384,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
         init_insn_depends_on = frozenset()
 
-        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+        global_barrier = lp.find_most_recent_global_barrier(temp_kernel, insn.id)
 
         if global_barrier is not None:
             init_insn_depends_on |= frozenset([global_barrier])
@@ -1361,9 +1400,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
         generated_insns.append(init_insn)
 
-        updated_inner_exprs = tuple(
-                replace_var_within_expr(sub_expr, scan_iname, track_iname)
-                for sub_expr in expr.exprs)
+        update_insn_depends_on = set([init_insn.id]) | insn.depends_on
+
+        updated_inner_exprs = (
+                preprocess_scan_arguments(insn, expr.expr, nresults,
+                    scan_iname, track_iname, update_insn_depends_on))
 
         update_id = insn_id_gen(
                 based_on="%s_%s_update" % (insn.id, "_".join(expr.inames)))
@@ -1379,7 +1420,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                     arg_dtypes,
                     _strip_if_scalar(acc_vars, acc_vars),
                     _strip_if_scalar(acc_vars, updated_inner_exprs)),
-                depends_on=frozenset([init_insn.id]) | insn.depends_on,
+                depends_on=frozenset(update_insn_depends_on),
                 within_inames=update_insn_iname_deps,
                 no_sync_with=insn.no_sync_with,
                 within_inames_is_final=insn.within_inames_is_final)
@@ -1474,7 +1515,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
 
         init_insn_depends_on = insn.depends_on
 
-        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+        global_barrier = lp.find_most_recent_global_barrier(temp_kernel, insn.id)
 
         if global_barrier is not None:
             init_insn_depends_on |= frozenset([global_barrier])
@@ -1491,9 +1532,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                 depends_on=init_insn_depends_on)
         generated_insns.append(init_insn)
 
-        updated_inner_exprs = tuple(
-                replace_var_within_expr(sub_expr, scan_iname, track_iname)
-                for sub_expr in expr.exprs)
+        transfer_insn_depends_on = set([init_insn.id]) | insn.depends_on
+
+        updated_inner_exprs = (
+                preprocess_scan_arguments(insn, expr.expr, nresults,
+                    scan_iname, track_iname, transfer_insn_depends_on))
 
         from loopy.symbolic import Reduction
 
@@ -1510,22 +1553,16 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                 expression=Reduction(
                     operation=expr.operation,
                     inames=(track_iname,),
-                    exprs=updated_inner_exprs,
+                    expr=_strip_if_scalar(acc_vars, updated_inner_exprs),
                     allow_simultaneous=False,
                     ),
                 within_inames=outer_insn_inames - frozenset(expr.inames),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset([init_id]) | insn.depends_on,
+                depends_on=frozenset(transfer_insn_depends_on),
                 no_sync_with=frozenset([(init_id, "any")]) | insn.no_sync_with)
 
         generated_insns.append(transfer_insn)
 
-        def _strip_if_scalar(c):
-            if len(acc_vars) == 1:
-                return c[0]
-            else:
-                return c
-
         prev_id = transfer_id
 
         istage = 0
@@ -1574,8 +1611,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                         for acc_var in acc_vars),
                     expression=expr.operation(
                         arg_dtypes,
-                        _strip_if_scalar(read_vars),
-                        _strip_if_scalar(tuple(
+                        _strip_if_scalar(acc_vars, read_vars),
+                        _strip_if_scalar(acc_vars, tuple(
                             acc_var[
                                 outer_local_iname_vars + (var(stage_exec_iname),)]
                             for acc_var in acc_vars))
@@ -1631,7 +1668,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
         n_nonlocal_par = len(iname_classes.nonlocal_parallel)
 
         really_force_scan = force_scan and (
-            len(expr.inames) != 1 or expr.inames[0] not in inames_added_for_scan)
+                len(expr.inames) != 1 or expr.inames[0] not in inames_added_for_scan)
 
         def _error_if_force_scan_on(cls, msg):
             if really_force_scan:
diff --git a/test/test_scan.py b/test/test_scan.py
new file mode 100644
index 000000000..fab7f1b0a
--- /dev/null
+++ b/test/test_scan.py
@@ -0,0 +1,432 @@
+from __future__ import division, absolute_import, print_function
+
+__copyright__ = """
+Copyright (C) 2012 Andreas Kloeckner
+Copyright (C) 2016, 2017 Matt Wala
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import sys
+import numpy as np
+import loopy as lp
+import pyopencl as cl
+import pyopencl.clmath  # noqa
+import pyopencl.clrandom  # noqa
+import pytest
+
+import logging
+logger = logging.getLogger(__name__)
+
+try:
+    import faulthandler
+except ImportError:
+    pass
+else:
+    faulthandler.enable()
+
+from pyopencl.tools import pytest_generate_tests_for_pyopencl \
+        as pytest_generate_tests
+
+__all__ = [
+        "pytest_generate_tests",
+        "cl"  # 'cl.create_some_context'
+        ]
+
+
+# More things to test:
+# - scan(a) + scan(b)
+# - test for badly tagged inames
+
+@pytest.mark.parametrize("n", [1, 2, 3, 16])
+@pytest.mark.parametrize("stride", [1, 2])
+def test_sequential_scan(ctx_factory, n, stride):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        "[n] -> {[i,j]: 0<=i<n and 0<=j<=%d*i}" % stride,
+        """
+        a[i] = sum(j, j**2)
+        """
+        )
+
+    knl = lp.fix_parameters(knl, n=n)
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    evt, (a,) = knl(queue)
+
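+    # Row i sums j**2 over 0 <= j <= stride*i, i.e. the cumulative sum of
+    # squares sampled at every stride-th index.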
+    assert (a.get() == np.cumsum(np.arange(stride*n)**2)[::stride]).all()
+
+
+@pytest.mark.parametrize("sweep_lbound, scan_lbound", [
+    (4, 0),
+    (3, 1),
+    (2, 2),
+    (1, 3),
+    (0, 4),
+    (5, -1),
+    ])
+def test_scan_with_different_lower_bound_from_sweep(
+        ctx_factory, sweep_lbound, scan_lbound):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        "[n, sweep_lbound, scan_lbound] -> "
+        "{[i,j]: sweep_lbound<=i<n+sweep_lbound "
+        "and scan_lbound<=j<=2*(i-sweep_lbound)+scan_lbound}",
+        """
+        out[i-sweep_lbound] = sum(j, j**2)
+        """
+        )
+
+    n = 10
+
+    knl = lp.fix_parameters(knl, sweep_lbound=sweep_lbound, scan_lbound=scan_lbound)
+    knl = lp.realize_reduction(knl, force_scan=True)
+    evt, (out,) = knl(queue, n=n)
+
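+    # out[i-sweep_lbound] sums j**2 over scan_lbound <= j <= 2*(i-sweep_lbound)
+    # + scan_lbound, i.e. every other entry of the cumulative sum of squares
+    # starting at scan_lbound.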
+    assert (out.get()
+            == np.cumsum(np.arange(scan_lbound, 2*n+scan_lbound)**2)[::2]).all()
+
+
+def test_automatic_scan_detection():
+    knl = lp.make_kernel(
+        [
+            "[n] -> {[i]: 0<=i<n}",
+            "{[j]: 0<=j<=2*i}"
+        ],
+        """
+        a[i] = sum(j, j**2)
+        """
+        )
+
+    cgr = lp.generate_code_v2(knl)
+    assert "scan" in cgr.device_code()
+
+
+def test_selective_scan_realization():
+    pass
+
+
+def test_force_outer_iname_for_scan():
+    knl = lp.make_kernel(
+        "[n] -> {[i,j,k]: 0<=k<n and 0<=i<=k and 0<=j<=i}",
+        "out[i] = product(j, a[j]) {inames=i:k}")
+
+    knl = lp.add_dtypes(knl, dict(a=np.float32))
+
+    # TODO: Maybe this deserves to work?
+    with pytest.raises(lp.diagnostic.ReductionIsNotTriangularError):
+        lp.realize_reduction(knl, force_scan=True)
+
+    knl = lp.realize_reduction(knl, force_scan=True, force_outer_iname_for_scan="i")
+
+
+def test_dependent_domain_scan(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        [
+            "[n] -> {[i]: 0<=i<n}",
+            "{[j]: 0<=j<=2*i}"
+        ],
+        """
+        a[i] = sum(j, j**2) {id=scan}
+        """
+        )
+    knl = lp.realize_reduction(knl, force_scan=True)
+    evt, (a,) = knl(queue, n=100)
+
+    assert (a.get() == np.cumsum(np.arange(200)**2)[::2]).all()
+
+
+@pytest.mark.parametrize("i_tag, j_tag", [
+    ("for", "for")
+    ])
+def test_nested_scan(ctx_factory, i_tag, j_tag):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        [
+            "[n] -> {[i]: 0 <= i < n}",
+            "[i] -> {[j]: 0 <= j <= i}",
+            "[i] -> {[k]: 0 <= k <= i}"
+        ],
+        """
+        <>tmp[i] = sum(k, 1)
+        out[i] = sum(j, tmp[j])
+        """)
+
+    knl = lp.fix_parameters(knl, n=10)
+    knl = lp.tag_inames(knl, dict(i=i_tag, j=j_tag))
+
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    print(knl)
+
+    evt, (out,) = knl(queue)
+
+    print(out)
+
+
+def test_scan_not_triangular():
+    knl = lp.make_kernel(
+        "{[i,j]: 0<=i<100 and 1<=j<=2*i}",
+        """
+        a[i] = sum(j, j**2)
+        """
+        )
+
+    with pytest.raises(lp.diagnostic.ReductionIsNotTriangularError):
+        knl = lp.realize_reduction(knl, force_scan=True)
+
+
+@pytest.mark.parametrize("n", [1, 2, 3, 16, 17])
+def test_local_parallel_scan(ctx_factory, n):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        "[n] -> {[i,j]: 0<=i<n and 0<=j<=i}",
+        """
+        out[i] = sum(j, a[j]**2)
+        """,
+        "..."
+        )
+
+    knl = lp.fix_parameters(knl, n=n)
+    knl = lp.tag_inames(knl, dict(i="l.0"))
+    knl = lp.realize_reduction(knl, force_scan=True)
+
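+    # The scan realization above leaves behind an inner reduction; realize it
+    # in a second pass.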
+    knl = lp.realize_reduction(knl)
+
+    knl = lp.add_dtypes(knl, dict(a=int))
+
+    print(knl)
+
+    evt, (a,) = knl(queue, a=np.arange(n))
+    assert (a == np.cumsum(np.arange(n)**2)).all()
+
+
+def test_local_parallel_scan_with_nonzero_lower_bounds(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        "[n] -> {[i,j]: 1<=i<n+1 and 0<=j<=i-1}",
+        """
+        out[i-1] = sum(j, a[j]**2)
+        """,
+        "..."
+        )
+
+    knl = lp.fix_parameters(knl, n=16)
+    knl = lp.tag_inames(knl, dict(i="l.0"))
+    knl = lp.realize_reduction(knl, force_scan=True)
+    knl = lp.realize_reduction(knl)
+
+    knl = lp.add_dtypes(knl, dict(a=int))
+    evt, (out,) = knl(queue, a=np.arange(1, 17))
+
+    assert (out == np.cumsum(np.arange(1, 17)**2)).all()
+
+
+def test_scan_extra_constraints_on_domain():
+    knl = lp.make_kernel(
+        "{[i,j,k]: 0<=i<n and 0<=j<=i and i=k}",
+        "out[i] = sum(j, a[j])")
+
+    with pytest.raises(lp.diagnostic.ReductionIsNotTriangularError):
+        knl = lp.realize_reduction(
+                knl, force_scan=True, force_outer_iname_for_scan="i")
+
+
+@pytest.mark.parametrize("sweep_iname_tag", ["for", "l.1"])
+def test_scan_with_outer_parallel_iname(ctx_factory, sweep_iname_tag):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+        [
+            "{[k]: 0<=k<=1}",
+            "[n] -> {[i,j]: 0<=i<n and 0<=j<=i}"
+        ],
+        "out[k,i] = k + sum(j, j**2)"
+        )
+
+    knl = lp.tag_inames(knl, dict(k="l.0", i=sweep_iname_tag))
+    n = 10
+    knl = lp.fix_parameters(knl, n=n)
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    evt, (out,) = knl(queue)
+
+    inner = np.cumsum(np.arange(n)**2)
+
+    assert (out.get() == np.array([inner, 1 + inner])).all()
+
+
+@pytest.mark.parametrize("dtype", [
+    np.int32, np.int64, np.float32, np.float64])
+def test_scan_data_types(ctx_factory, dtype):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+            "{[i,j]: 0<=i<n and 0<=j<=i }",
+            "res[i] = reduce(sum, j, a[j])",
+            assumptions="n>=1")
+
+    a = np.random.randn(20).astype(dtype)
+    knl = lp.add_dtypes(knl, dict(a=dtype))
+    knl = lp.realize_reduction(knl, force_scan=True)
+    evt, (res,) = knl(queue, a=a)
+
+    assert np.allclose(res, np.cumsum(a))
+
+
+@pytest.mark.parametrize(("op_name", "np_op"), [
+    ("sum", np.sum),
+    ("product", np.prod),
+    ("min", np.min),
+    ("max", np.max),
+    ])
+def test_scan_library(ctx_factory, op_name, np_op):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+            "{[i,j]: 0<=i<n and 0<=j<=i }",
+            "res[i] = reduce(%s, j, a[j])" % op_name,
+            assumptions="n>=1")
+
+    a = np.random.randn(20)
+    knl = lp.add_dtypes(knl, dict(a=np.float64))
+    knl = lp.realize_reduction(knl, force_scan=True)
+    evt, (res,) = knl(queue, a=a)
+
+    assert np.allclose(res, np.array(
+            [np_op(a[:i+1]) for i in range(len(a))]))
+
+
+def test_scan_unsupported_tags():
+    pass
+
+
+@pytest.mark.parametrize("i_tag", ["for", "l.0"])
+def test_argmax(ctx_factory, i_tag):
+    logging.basicConfig(level=logging.INFO)
+
+    dtype = np.dtype(np.float32)
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    n = 128
+
+    knl = lp.make_kernel(
+            "{[i,j]: 0<=i<%d and 0<=j<=i}" % n,
+            """
+            max_vals[i], max_indices[i] = argmax(j, fabs(a[j]), j)
+            """)
+
+    knl = lp.tag_inames(knl, dict(i=i_tag))
+    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    a = np.random.randn(n).astype(dtype)
+    evt, (max_indices, max_vals) = knl(queue, a=a, out_host=True)
+
+    assert (max_vals == [np.max(np.abs(a)[0:i+1]) for i in range(n)]).all()
+    assert (max_indices == [np.argmax(np.abs(a[0:i+1])) for i in range(n)]).all()
+
+
+def check_segmented_scan_output(arr, segment_boundaries_indices, out):
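+    """Split *arr* into segments at *segment_boundaries_indices* and check
+    that *out* holds the per-segment cumulative sums of *arr*."""
+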
+    class SegmentGrouper(object):
+
+        def __init__(self):
+            self.seg_idx = 0
+            self.idx = 0
+
+        def __call__(self, key):
+            if self.idx in segment_boundaries_indices:
+                self.seg_idx += 1
+            self.idx += 1
+            return self.seg_idx
+
+    from itertools import groupby
+
+    expected = [np.cumsum(list(group))
+            for _, group in groupby(arr, SegmentGrouper())]
+    actual = [np.array(list(group))
+            for _, group in groupby(out, SegmentGrouper())]
+
+    assert len(expected) == len(actual) == len(segment_boundaries_indices)
+    assert all((e == a).all() for e, a in zip(expected, actual))
+
+
+@pytest.mark.parametrize("n, segment_boundaries_indices", [
+    (1, (0,)),
+    (2, (0,)),
+    (2, (0, 1)),
+    (3, (0,)),
+    (3, (0, 1)),
+    (3, (0, 2)),
+    (3, (0, 1, 2)),
+    (16, (0, 4, 8, 12))])
+@pytest.mark.parametrize("iname_tag", ("for", "l.0"))
+def test_segmented_scan(ctx_factory, n, segment_boundaries_indices, iname_tag):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    arr = np.ones(n, dtype=np.float32)
+    segment_boundaries = np.zeros(n, dtype=np.int32)
+    segment_boundaries[(segment_boundaries_indices,)] = 1
+
+    knl = lp.make_kernel(
+        "{[i,j]: 0<=i<n and 0<=j<=i}",
+        "out[i], <>_ = reduce(segmented(sum), j, arr[j], segflag[j])",
+        [
+            lp.GlobalArg("arr", np.float32, shape=("n",)),
+            lp.GlobalArg("segflag", np.int32, shape=("n",)),
+            "..."
+        ])
+
+    knl = lp.fix_parameters(knl, n=n)
+    knl = lp.tag_inames(knl, dict(i=iname_tag))
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    (evt, (out,)) = knl(queue, arr=arr, segflag=segment_boundaries)
+
+    check_segmented_scan_output(arr, segment_boundaries_indices, out)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        exec(sys.argv[1])
+    else:
+        from py.test.cmdline import main
+        main([__file__])
+
+# vim: foldmethod=marker
-- 
GitLab