From 3db910efac3737f7e00e476e1c3c6b95604802b5 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Sat, 4 Mar 2017 02:39:21 -0600
Subject: [PATCH] [ci skip] Two level reduction + two level scan, semi-working
 version.

---
 loopy/__init__.py            |   2 +
 loopy/preprocess.py          |  76 +++--
 loopy/transform/data.py      |   4 +-
 loopy/transform/reduction.py | 611 +++++++++++++++++++++++++++++++++++
 test/test_reduction.py       |  28 +-
 test/test_scan.py            | 101 ++++--
 6 files changed, 778 insertions(+), 44 deletions(-)
 create mode 100644 loopy/transform/reduction.py

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 6cbb3362e..a10d94463 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -108,6 +108,8 @@ from loopy.transform.batch import to_batched
 from loopy.transform.parameter import assume, fix_parameters
 from loopy.transform.save import save_and_reload_temporaries
 
+from loopy.transform.reduction import make_two_level_reduction
+
 # }}}
 
 from loopy.type_inference import infer_unknown_types
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 5c9a27805..f139810f1 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -328,26 +328,50 @@ def _add_params_to_domain(domain, param_names):
     return domain
 
 
+def _move_set_to_param_dims_except(domain, except_dims):
+    dim_type = isl.dim_type
+
+    iname_idx = 0
+    for iname in domain.get_var_names(dim_type.set):
+        if iname not in except_dims:
+            domain = domain.move_dims(
+                    dim_type.param, 0,
+                    dim_type.set, iname_idx, 1)
+            iname_idx -= 1
+        iname_idx += 1
+
+    return domain
+
+
 def _check_reduction_is_triangular(kernel, expr, scan_param):
     """Check whether the reduction within `expr` with scan parameters described by
     the structure `scan_param` is triangular. This attempts to verify that the
     domain for the scan and sweep inames is as follows:
 
-    [scan_iname, sweep_iname]:
-        (sweep_min_value
-            <= sweep_iname
-            <= sweep_max_value)
-        and
-        (scan_min_value
-            <= scan_iname
-            <= stride * (sweep_iname - sweep_min_value) + scan_min_value)
+    [other inames] -> {
+        [scan_iname, sweep_iname]:
+            (sweep_min_value
+                <= sweep_iname
+                <= sweep_max_value)
+            and
+            (scan_min_value
+                <= scan_iname
+                <= stride * (sweep_iname - sweep_min_value) + scan_min_value)
+    }
     """
 
     dim_type = isl.dim_type
 
-    domain = kernel.get_inames_domain(
+    orig_domain = kernel.get_inames_domain(
+            (scan_param.sweep_iname, scan_param.scan_iname))
+
+    domain = _move_set_to_param_dims_except(orig_domain,
             (scan_param.sweep_iname, scan_param.scan_iname))
 
+    params_for_gisting = domain.params()
+
+    domain = domain.gist_params(params_for_gisting)
+
     tri_domain = isl.BasicSet.universe(domain.params().space)
 
     sweep_iname = scan_param.sweep_iname
@@ -369,7 +393,7 @@ def _check_reduction_is_triangular(kernel, expr, scan_param):
             + scan_min_value)
 
     # Gist against domain params
-    tri_domain = tri_domain.gist(domain.params())
+    tri_domain = tri_domain.gist_params(params_for_gisting)
 
     # Move sweep and scan inames into the set
     tri_domain = tri_domain.move_dims(
@@ -418,7 +442,7 @@ def _try_infer_scan_candidate_from_expr(kernel, expr, sweep_iname=None):
     try:
         sweep_lower_bound, sweep_upper_bound, scan_lower_bound = (
                 _try_infer_scan_and_sweep_bounds(kernel, scan_iname, sweep_iname))
-    except Exception as e:
+    except ValueError as v:
         raise ValueError("Couldn't determine bounds for scan: %s" % e)
 
     try:
@@ -482,12 +506,18 @@ def _try_infer_sweep_iname(domain, scan_iname, candidate_inames):
 
 
 def _try_infer_scan_and_sweep_bounds(kernel, scan_iname, sweep_iname):
-    sweep_bounds = kernel.get_iname_bounds(sweep_iname)
-    scan_bounds = kernel.get_iname_bounds(scan_iname)
+    # FIXME: use home domain of scan_iname...
+    domain = kernel.get_inames_domain((sweep_iname, scan_iname))
+    domain = _move_set_to_param_dims_except(domain, (sweep_iname, scan_iname))
 
-    return (sweep_bounds.lower_bound_pw_aff,
-            sweep_bounds.upper_bound_pw_aff,
-            scan_bounds.lower_bound_pw_aff)
+    domain = domain.gist_params(domain.params()).project_out_except(
+            (sweep_iname,), (isl.dim_type.param,))
+
+    sweep_lower_bound = domain.dim_min(domain.get_var_dict()[sweep_iname][1])
+    sweep_upper_bound = domain.dim_max(domain.get_var_dict()[sweep_iname][1])
+    scan_lower_bound = domain.dim_min(domain.get_var_dict()[scan_iname][1])
+
+    return (sweep_lower_bound, sweep_upper_bound, scan_lower_bound)
 
 
 def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound):
@@ -499,7 +529,7 @@ def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound):
     dim_type = isl.dim_type
 
     domain = kernel.get_inames_domain((sweep_iname, scan_iname))
-    domain_with_sweep_param = _get_domain_with_iname_as_param(domain, sweep_iname)
+    domain_with_sweep_param = _move_set_to_param_dims_except(domain, (scan_iname,))
 
     scan_iname_idx = domain_with_sweep_param.find_dim_by_name(
             dim_type.set, scan_iname)
@@ -659,11 +689,11 @@ def _infer_arg_dtypes_and_reduction_dtypes(kernel, expr, unknown_types_ok):
 def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel):
     """
     Multi assignment function calls are currently lowered into OpenCL so that
-    the function call:
+    the function call::
 
        a, b = segmented_sum(x, y, z, w)
 
-    becomes
+    becomes::
 
        a = segmented_sum_mangled(x, y, z, w, &b).
 
@@ -835,6 +865,12 @@ def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel):
     return kernel.copy(temporary_variables=new_temporary_variables,
                        instructions=new_instructions)
 
+
+def _insert_subdomain_into_domain_tree(kernel, domains, subdomain):
+    dependent_inames = frozenset(subdomain.get_var_names(isl.dim_type.param))
+    idx, = kernel.get_leaf_domain_indices(dependent_inames)
+    domains.insert(idx + 1, subdomain)
+
 # }}}
 
 
@@ -1140,7 +1176,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
         new_domain = _create_domain_for_sweep_tracking(domain,
                 tracking_iname, sweep_iname, sweep_min_value, scan_min_value, stride)
 
-        domains.append(new_domain)
+        _insert_subdomain_into_domain_tree(kernel, domains, new_domain)
 
         return tracking_iname, new_domain
 
diff --git a/loopy/transform/data.py b/loopy/transform/data.py
index 4014b8575..c6ff596b0 100644
--- a/loopy/transform/data.py
+++ b/loopy/transform/data.py
@@ -694,12 +694,14 @@ def reduction_arg_to_subst_rule(knl, inames, insn_match=None, subst_rule_name=No
 
     var_name_gen = knl.get_var_name_generator()
 
+    # XXX
     def map_reduction(expr, rec, nresults=1):
         if frozenset(expr.inames) != inames_set:
+            assert len(expr.exprs) == 1
             return type(expr)(
                     operation=expr.operation,
                     inames=expr.inames,
-                    expr=rec(expr.expr),
+                    exprs=(rec(expr.exprs[0]),),
                     allow_simultaneous=expr.allow_simultaneous)
 
         if subst_rule_name is None:
diff --git a/loopy/transform/reduction.py b/loopy/transform/reduction.py
new file mode 100644
index 000000000..1693fb515
--- /dev/null
+++ b/loopy/transform/reduction.py
@@ -0,0 +1,611 @@
+from __future__ import division, absolute_import
+
+__copyright__ = "Copyright (C) 2017 Matt Wala"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+from loopy.diagnostic import LoopyError
+import loopy as lp
+
+from loopy.kernel.data import auto, temp_var_scope
+from pytools import memoize_method, Record
+import islpy as isl
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+__doc__ = """
+.. currentmodule:: loopy
+
+.. autofunction:: make_two_level_reduction
+.. autofunction:: make_two_level_scan
+.. autofunction:: precompute_scan
+"""
+
+
+def make_two_level_reduction(
+        kernel, insn_id, inner_length,
+        nonlocal_storage_scope=None,
+        nonlocal_tag=None,
+        outer_tag=None,
+        inner_tag=None):
+    """
+    Two level reduction, mediated through a "nonlocal" array.
+
+    This turns a reduction of the form::
+
+         [...] result = reduce(i, f(i))
+
+    into::
+
+         i -> inner + inner_length * outer
+
+         [..., nl] nonlocal[nl] = reduce(inner, f(nl, inner))
+         [...]     result       = reduce(outer, nonlocal[outer])
+    """
+
+    # {{{ sanity checks
+
+    reduction = kernel.id_to_insn[insn_id].expression
+    reduction_iname, = reduction.inames
+
+    # }}}
+
+    # {{{ get stable names for everything
+
+    var_name_gen = kernel.get_var_name_generator()
+    insn_id_gen = kernel.get_instruction_id_generator()
+
+    format_kwargs = {"insn": insn_id, "iname": reduction_iname}
+
+    nonlocal_storage_name = var_name_gen(
+            "{insn}_nonlocal".format(**format_kwargs))
+
+    inner_iname = var_name_gen(
+            "{iname}_inner".format(**format_kwargs))
+    outer_iname = var_name_gen(
+            "{iname}_outer".format(**format_kwargs))
+    nonlocal_iname = var_name_gen(
+            "{iname}_nonlocal".format(**format_kwargs))
+
+    inner_subst = var_name_gen(
+            "{insn}_inner_subst".format(**format_kwargs))
+
+    # }}}
+
+    # First we split this iname. This results in (roughly)
+    #
+    # [...] result = reduce([outer, inner], f(outer, inner))
+    #
+    # FIXME: within
+
+    kernel = lp.split_iname(kernel, reduction_iname, inner_length,
+            outer_iname=outer_iname, inner_iname=inner_iname)
+
+    # Next, we split the reduction inward and then extract a substitution
+    # rule for the reduction. This results in
+    #
+    # subst(outer) := reduce(inner, f(outer, inner))
+    # [...] result = reduce([outer], subst(outer))
+    #
+    # FIXME: within, insn_match...
+
+    kernel = lp.split_reduction_inward(kernel, inner_iname)
+    from loopy.transform.data import reduction_arg_to_subst_rule
+    kernel = reduction_arg_to_subst_rule(kernel, outer_iname,
+                                         subst_rule_name=inner_subst)
+
+    # Next, we precompute the inner iname into its own storage.
+
+    # [...,nl] nonlocal[nl] = reduce(inner, f(nl, inner))
+    # [...] result = reduce([outer], nonlocal[outer])
+
+    kernel = lp.precompute(kernel, inner_subst,
+                           sweep_inames=[outer_iname],
+                           precompute_inames=[nonlocal_iname],
+                           temporary_name=nonlocal_storage_name,
+                           temporary_scope=nonlocal_storage_scope)
+
+    return kernel
+
+
+def _update_instructions(kernel, id_to_new_insn, copy=True):
+    if not isinstance(id_to_new_insn, dict):
+        id_to_new_insn = dict((insn.id, insn) for insn in id_to_new_insn)
+    
+    new_instructions = (
+        list(insn for insn in kernel.instructions
+             if insn.id not in id_to_new_insn)
+        + list(id_to_new_insn.values()))
+
+    if copy:
+        kernel = kernel.copy()
+
+    kernel.instructions = new_instructions
+    return kernel
+
+
+def _make_slab_set(iname, size):
+    # FIXME: stolen from preprocess, should be its own thing...
+    v = isl.make_zero_and_vars([iname])
+    bs, = (
+            v[0].le_set(v[iname])
+            &
+            v[iname].lt_set(v[0] + size)).get_basic_sets()
+    print("ADDING SLAB", bs)
+    return bs
+
+
+def _add_scan_subdomain(
+        kernel, scan_iname, sweep_iname):
+    """
+    Add the following domain to the kernel::
+
+        [sweep_iname] -> {[scan_iname] : 0 <= scan_iname <= sweep_iname }
+    """
+    sp = (
+            isl.Space.set_alloc(isl.DEFAULT_CONTEXT, 1, 1)
+            .set_dim_name(isl.dim_type.param, 0, sweep_iname)
+            .set_dim_name(isl.dim_type.set, 0, scan_iname))
+
+    affs = isl.affs_from_space(sp)
+
+    subd, = (
+            affs[scan_iname].le_set(affs[sweep_iname])
+            &
+            affs[scan_iname].ge_set(affs[0])).get_basic_sets()
+
+    sweep_idx, = kernel.get_leaf_domain_indices((sweep_iname,))
+
+    domains = list(kernel.domains)
+    domains.insert(sweep_idx + 1, subd)
+
+    return kernel.copy(domains=domains)
+
+
+def _expand_subst_within_expression(kernel, expr):
+    from loopy.symbolic import RuleAwareSubstitutionRuleExpander, SubstitutionRuleMappingContext
+    from loopy.match import parse_stack_match
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    submap = RuleAwareSubstitutionRuleExpander(
+            rule_mapping_context,
+            kernel.substitutions,
+            within=lambda *args: True
+            )
+    return submap(expr, kernel, insn=None)
+
+
+def make_two_level_scan(
+        kernel, insn_id,
+        scan_iname,
+        sweep_iname,
+        inner_length,
+        local_storage_name=None,
+        local_storage_scope=None,
+        local_storage_axes=None,
+        nonlocal_storage_name=None,
+        nonlocal_storage_scope=None,
+        nonlocal_tag=None,
+        outer_local_tag=None,
+        inner_local_tag=None,
+        inner_tag=None,
+        outer_tag=None,
+        inner_local_iname=None,
+        outer_local_iname=None):
+    """
+    Two level scan, mediated through a "local" and "nonlocal" array.
+
+    This turns a scan of the form::
+
+         [...,i] result = reduce(j, f(j))
+
+    into::
+
+         [...,l',l''] <scan into local>
+         [...,l']     nonlocal[0] = 0
+         [...,l']     nonlocal[l'+1] = local[l',-1]
+         [...,nl]     <scan into nonlocal>
+         [...,i',i''] result = nonlocal[i'] + local[i',i'']
+    """
+
+    # {{{ sanity checks
+
+    insn = kernel.id_to_insn[insn_id]
+    scan = insn.expression
+    assert scan.inames[0] == scan_iname
+    assert len(scan.inames) == 1
+
+    # }}}
+
+    # {{{ get stable names for everything
+
+    var_name_gen = kernel.get_var_name_generator()
+    insn_id_gen = kernel.get_instruction_id_generator()
+
+    format_kwargs = {"insn": insn_id, "iname": scan_iname, "sweep": sweep_iname}
+
+    nonlocal_storage_name = var_name_gen(
+            "{insn}_nonlocal".format(**format_kwargs))
+
+    inner_iname = var_name_gen(
+            "{sweep}_inner".format(**format_kwargs))
+    outer_iname = var_name_gen(
+            "{sweep}_outer".format(**format_kwargs))
+    nonlocal_iname = var_name_gen(
+            "{sweep}_nonlocal".format(**format_kwargs))
+
+    if inner_local_iname is None:
+        inner_local_iname = var_name_gen(
+                "{sweep}_inner_local".format(**format_kwargs))
+
+    inner_scan_iname = var_name_gen(
+            "{iname}_inner".format(**format_kwargs))
+
+    outer_scan_iname = var_name_gen(
+            "{iname}_outer".format(**format_kwargs))
+
+    if outer_local_iname is None:
+        outer_local_iname = var_name_gen(
+                "{sweep}_outer_local".format(**format_kwargs))
+
+    subst_name = var_name_gen(
+            "{insn}_inner_subst".format(**format_kwargs))
+
+    local_subst_name = var_name_gen(
+            "{insn}_local_subst".format(**format_kwargs))
+
+    if local_storage_name is None:
+        local_storage_name = var_name_gen(
+            "{insn}_local".format(**format_kwargs))
+
+    if nonlocal_storage_name is None:
+        nonlocal_storage_name = var_name_gen(
+            "{insn}_nonlocal".format(**format_kwargs))
+
+    local_scan_insn_id = insn_id_gen(
+            "{iname}_local_scan".format(**format_kwargs))
+
+    nonlocal_scan_insn_id = insn_id_gen(
+            "{iname}_nonlocal_scan".format(**format_kwargs))
+
+    format_kwargs.update({"nonlocal": nonlocal_storage_name})
+
+    nonlocal_init_head_insn_id = insn_id_gen(
+            "{nonlocal}_init_head".format(**format_kwargs))
+
+    nonlocal_init_tail_insn_id = insn_id_gen(
+            "{nonlocal}_init_tail".format(**format_kwargs))
+
+    # }}}
+
+    # Turn the scan into a substitution rule, replace the original scan with a
+    # nop and delete the scan iname.
+    #
+    # (The presence of the scan iname seems to be making precompute very confused.)
+
+    from loopy.transform.data import reduction_arg_to_subst_rule
+    kernel = reduction_arg_to_subst_rule(
+            kernel, scan_iname, subst_rule_name=subst_name)
+
+    from loopy.kernel.instruction import NoOpInstruction
+    # FIXME: this is stupid
+    kernel = _update_instructions(kernel, {insn_id: insn.copy(expression=0)})
+    """
+            {insn_id: NoOpInstruction(
+                id=insn_id,
+                depends_on=insn.depends_on,
+                groups=insn.groups,
+                conflicts_with_groups=insn.groups,
+                no_sync_with=insn.no_sync_with,
+                within_inames_is_final=insn.within_inames_is_final,
+                within_inames=insn.within_inames,
+                priority=insn.priority,
+                boostable=insn.boostable,
+                boostable_into=insn.boostable_into,
+                predicates=insn.predicates,
+                tags=insn.tags)},
+            copy=False)
+    """
+
+    kernel = lp.remove_unused_inames(kernel, inames=(scan_iname,))
+
+    # Make sure we got rid of everything
+    assert scan_iname not in kernel.all_inames()
+
+    # {{{ implement local scan
+
+    from pymbolic import var
+    local_scan_expr = _expand_subst_within_expression(kernel,
+            var(subst_name)(var(outer_local_iname) * inner_length +
+                            var(inner_scan_iname)))
+
+    kernel = lp.split_iname(kernel, sweep_iname, inner_length,
+            inner_iname=inner_iname, outer_iname=outer_iname)
+
+    print("SPLITTING INAME, GOT DOMAINS", kernel.domains)
+
+    from loopy.kernel.data import SubstitutionRule
+    from loopy.symbolic import Reduction
+
+    local_subst = SubstitutionRule(
+            name=local_subst_name,
+            arguments=(outer_iname, inner_iname),
+            expression=Reduction(
+                scan.operation,
+                (inner_scan_iname,),
+                local_scan_expr)
+            )
+
+    substitutions = kernel.substitutions.copy()
+    substitutions[local_subst_name] = local_subst
+
+    kernel = kernel.copy(substitutions=substitutions)
+
+    print(kernel)
+
+    from pymbolic import var
+    kernel = lp.precompute(
+            kernel,
+            [var(local_subst_name)(var(outer_iname), var(inner_iname))],
+            storage_axes=(outer_iname, inner_iname),
+            sweep_inames=(outer_iname, inner_iname),
+            precompute_inames=(outer_local_iname, inner_local_iname),
+            temporary_name=local_storage_name,
+            compute_insn_id=local_scan_insn_id)
+
+    kernel = _add_scan_subdomain(kernel, inner_scan_iname, inner_local_iname)
+
+    # }}}
+
+    # {{{ implement local to nonlocal information transfer
+
+    from loopy.symbolic import pw_aff_to_expr
+    nonlocal_storage_len_pw_aff = (
+            # The 2 here is because the first element is 0.
+            2 + kernel.get_iname_bounds(outer_iname).upper_bound_pw_aff)
+
+    nonlocal_storage_len = pw_aff_to_expr(nonlocal_storage_len_pw_aff)
+
+    if nonlocal_storage_name not in kernel.temporary_variables:
+        from loopy.kernel.data import TemporaryVariable
+        new_temporary_variables = kernel.temporary_variables.copy()
+
+        new_temporary_variables[nonlocal_storage_name] = (
+                TemporaryVariable(
+                    nonlocal_storage_name,
+                    shape=(nonlocal_storage_len,),
+                    scope=lp.auto,
+                    base_indices=lp.auto,
+                    dtype=lp.auto))
+
+        kernel = kernel.copy(temporary_variables=new_temporary_variables)
+
+    insn = kernel.id_to_insn[insn_id]
+
+    # XXX: should not include sweep iname?
+    within_inames = insn.within_inames
+
+    from loopy.kernel.instruction import make_assignment
+    nonlocal_init_head = make_assignment(
+            id=nonlocal_init_head_insn_id,
+            assignees=(var(nonlocal_storage_name)[0],),
+            expression=0,
+            within_inames=frozenset([outer_local_iname]),
+            depends_on=frozenset([local_scan_insn_id]))
+
+    final_element_indices = []
+
+    nonlocal_init_tail = make_assignment(
+            id=nonlocal_init_tail_insn_id,
+            assignees=(var(nonlocal_storage_name)[var(outer_local_iname) + 1],),
+            expression=var(local_storage_name)[var(outer_local_iname),inner_length - 1],
+            within_inames=frozenset([outer_local_iname]),
+            depends_on=frozenset([local_scan_insn_id]))
+
+    kernel = _update_instructions(kernel, (nonlocal_init_head, nonlocal_init_tail), copy=False)
+
+    # }}}
+
+    # {{{ implement nonlocal scan
+
+    kernel.domains.append(_make_slab_set(nonlocal_iname, nonlocal_storage_len))
+
+    kernel = _add_scan_subdomain(kernel, outer_scan_iname, nonlocal_iname)
+    
+    nonlocal_scan = make_assignment(
+            id=nonlocal_scan_insn_id,
+            assignees=(var(nonlocal_storage_name)[var(nonlocal_iname)],),
+            expression=Reduction(
+                scan.operation,
+                (outer_scan_iname,),
+                var(nonlocal_storage_name)[var(outer_scan_iname)]),
+            within_inames=frozenset([nonlocal_iname]),
+            depends_on=frozenset([nonlocal_init_tail_insn_id, nonlocal_init_head_insn_id]))
+
+    kernel = _update_instructions(kernel, (nonlocal_scan,), copy=False)
+
+    # }}}
+
+    # {{{ replace scan with local + nonlocal
+
+    updated_insn = insn.copy(
+        depends_on=insn.depends_on | frozenset([nonlocal_scan_insn_id]),
+        expression=var(nonlocal_storage_name)[var(outer_iname)] + var(local_storage_name)[var(outer_iname), var(inner_iname)])
+
+    kernel = _update_instructions(kernel, (updated_insn,), copy=False)
+
+    # }}}
+
+    return kernel
+
+
+def precompute_scan(
+        kernel, insn_id,
+        sweep_iname,
+        scan_iname,
+        outer_inames=(),
+        temporary_scope=None,
+        temporary_name=None,
+        replace_insn_with_nop=False):
+    """
+    Turn an expression-based scan into an array-based one.
+
+    This takes a reduction of the form::
+
+        [...,sweep_iname] result = reduce(scan_iname, f(scan_iname))
+
+    and does essentially the following transformation::
+
+        [...,sweep_iname'] temp[sweep_iname'] = f(sweep_iname')
+        [...,sweep_iname] temp[sweep_iname] = reduce(scan_iname, temp[scan_iname])
+        [...,sweep_iname] result = temp[sweep_iname]
+
+    Note: this makes an explicit assumption that the sweep iname shares the
+    same bounds as the scan iname and the bounds start at 0.
+    """
+
+    # {{{ sanity checks
+
+    insn = kernel.id_to_insn[insn_id]
+    scan = insn.expression
+    assert scan.inames[0] == scan_iname
+    assert len(scan.inames) == 1
+
+    # }}}
+
+    # {{{ get a stable name for things
+
+    var_name_gen = kernel.get_var_name_generator()
+    insn_id_gen = kernel.get_instruction_id_generator()
+
+    format_kwargs = {"insn": insn_id, "iname": scan_iname}
+
+    orig_subst_name = var_name_gen(
+            "{iname}_orig_subst".format(**format_kwargs))
+
+    scan_subst_name = var_name_gen(
+            "{iname}_subst".format(**format_kwargs))
+
+    precompute_insn = insn_id_gen(
+            "{insn}_precompute".format(**format_kwargs))
+
+    precompute_reduction_insn = insn_id_gen(
+            "{insn}_precompute_reduce".format(**format_kwargs))
+
+    if temporary_name is None:
+        temporary_name = var_name_gen(
+            "{insn}_precompute".format(**format_kwargs))
+
+    # }}}
+
+    from loopy.transform.data import reduction_arg_to_subst_rule
+    kernel = reduction_arg_to_subst_rule(
+            kernel, scan_iname, subst_rule_name=orig_subst_name)
+
+    # {{{ create our own variant of the substitution rule
+
+    # FIXME: There has to be a better way of this.
+
+    orig_subst = kernel.substitutions[orig_subst_name]
+
+    from pymbolic.mapper.substitutor import make_subst_func
+
+    from loopy.symbolic import (
+        SubstitutionRuleMappingContext, RuleAwareSubstitutionMapper)
+
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, var_name_gen)
+
+    from pymbolic import var
+    mapper = RuleAwareSubstitutionMapper(
+            rule_mapping_context,
+            make_subst_func({scan_iname: var(sweep_iname)}),
+            within=lambda *args: True)
+
+    scan_subst = orig_subst.copy(
+            name=scan_subst_name,
+            arguments=outer_inames + (sweep_iname,),
+            expression=mapper(orig_subst.expression, kernel, None))
+
+    substitutions = kernel.substitutions.copy()
+
+    substitutions[scan_subst_name] = scan_subst
+
+    kernel = kernel.copy(substitutions=substitutions)
+
+    # }}}
+
+    print(kernel)
+
+    # FIXME: multi assignments
+    from pymbolic import var
+
+    # FIXME: Make a new precompute iname....
+
+    kernel = lp.precompute(kernel,
+            [var(scan_subst_name)(
+                *(tuple(var(o) for o in outer_inames) +
+                  (var(sweep_iname),)))],
+            sweep_inames=outer_inames + (sweep_iname,),
+            precompute_inames=(sweep_iname,),
+            temporary_name=temporary_name,
+            temporary_scope=temporary_scope,
+            # FIXME: why on earth is this needed
+            compute_insn_id=precompute_insn)
+
+    from loopy.kernel.instruction import make_assignment
+
+    from loopy.symbolic import Reduction
+    precompute_reduction = insn.copy(
+            id=precompute_reduction_insn,
+            assignee=var(temporary_name)[var(sweep_iname)],
+            expression=Reduction(
+                operation=scan.operation,
+                inames=(scan_iname,),
+                exprs=(var(temporary_name)[var(scan_iname)],),
+                allow_simultaneous=False,
+                ),
+            depends_on=insn.depends_on | frozenset([precompute_insn]))
+
+    kernel = kernel.copy(instructions=kernel.instructions +
+                         [precompute_reduction])
+
+    new_insn = insn.copy(
+           expression=var(temporary_name)[var(sweep_iname)],
+           depends_on=
+           frozenset([precompute_reduction_insn]) | insn.depends_on)
+
+    instructions = list(kernel.instructions)
+
+    for i, insn in enumerate(instructions):
+        if insn.id == insn_id:
+            instructions[i] = new_insn
+
+    kernel = kernel.copy(instructions=instructions)
+
+    return kernel
+
+
+# vim: foldmethod=marker
diff --git a/test/test_reduction.py b/test/test_reduction.py
index 290f3d483..96d85beb6 100644
--- a/test/test_reduction.py
+++ b/test/test_reduction.py
@@ -240,12 +240,38 @@ def test_global_parallel_reduction(ctx_factory, size):
     knl = lp.add_dependency(
             knl, "writes:acc_i_outer",
             "id:red_i_outer_arg_barrier")
-
     lp.auto_test_vs_ref(
             ref_knl, ctx, knl, parameters={"n": size},
             print_ref_code=True)
 
 
+def test_global_parallel_reduction_2():
+    knl = lp.make_kernel(
+            "{[i]: 0 <= i < n }",
+            """
+            # Using z[0] instead of z works around a bug in ancient PyOpenCL.
+            z[0] = sum(i, i/13) {id=reduce}
+            """)
+
+    gsize = 128
+    knl = lp.make_two_level_reduction(knl,
+            "reduce",
+            inner_length=gsize * 20,
+            nonlocal_tag="g.0",
+            nonlocal_storage_scope=lp.temp_var_scope.GLOBAL,
+            outer_tag=None,
+            inner_tag=None)
+
+    print(knl)
+
+    knl = lp.split_iname(knl, "i_inner", gsize, outer_tag="l.0")
+    knl = lp.split_reduction_inward(knl, "i_inner_inner")
+
+    knl = lp.realize_reduction(knl)
+
+    print(knl)
+
+
 @pytest.mark.parametrize("size", [1000])
 def test_global_mc_parallel_reduction(ctx_factory, size):
     ctx = ctx_factory()
diff --git a/test/test_scan.py b/test/test_scan.py
index d77a82d59..aabfe3031 100644
--- a/test/test_scan.py
+++ b/test/test_scan.py
@@ -2,7 +2,7 @@ from __future__ import division, absolute_import, print_function
 
 __copyright__ = """
 Copyright (C) 2012 Andreas Kloeckner
-Copyright (C) 2016 Matt Wala
+Copyright (C) 2016, 2017 Matt Wala
 """
 
 __license__ = """
@@ -54,10 +54,8 @@ __all__ = [
 
 # More things to test.
 # - test that dummy inames are removed
-# - nested sequential/parallel scan
 # - scan(a) + scan(b)
 # - global parallel scan
-# - base_exec_iname different bounds from sweep iname
 
 # TO DO:
 # segmented<sum>(...) syntax
@@ -71,11 +69,14 @@ def test_sequential_scan(ctx_factory, n, stride):
 
     knl = lp.make_kernel(
         "[n] -> {[i,j]: 0<=i<n and 0<=j<=%d*i}" % stride,
-        "a[i] = sum(j, j**2) {id=scan}"
+        """
+        a[i] = sum(j, j**2)
+        """
         )
 
     knl = lp.fix_parameters(knl, n=n)
     knl = lp.realize_reduction(knl, force_scan=True)
+
     evt, (a,) = knl(queue)
 
     assert (a.get() == np.cumsum(np.arange(stride*n)**2)[::stride]).all()
@@ -114,6 +115,21 @@ def test_scan_with_different_lower_bound_from_sweep(
 
 
 def test_automatic_scan_detection():
+    knl = lp.make_kernel(
+        [
+            "[n] -> {[i]: 0<=i<n}",
+            "{[j]: 0<=j<=2*i}"
+        ],
+        """
+        a[i] = sum(j, j**2)
+        """
+        )
+
+    cgr = lp.generate_code_v2(knl)
+    assert "tracking" in cgr.device_code()
+
+
+def test_selective_scan_realization():
     pass
 
 
@@ -136,23 +152,41 @@ def test_dependent_domain_scan(ctx_factory):
     assert (a.get() == np.cumsum(np.arange(200)**2)[::2]).all()
 
 
-"""
-def test_nested_scan():
+@pytest.mark.parametrize("i_tag, j_tag", [
+    ("for", "for")
+    ])
+def test_nested_scan(ctx_factory, i_tag, j_tag):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
     knl = lp.make_kernel(
         [
             "[n] -> {[i]: 0 <= i < n}",
-            "{[j]: 0 <= j <= i}",
-            "{[k]: 0 <= j <= k}"
+            "[i] -> {[j]: 0 <= j <= i}",
+            "[i] -> {[k]: 0 <= k <= i}"
         ],
-        "a[i] = sum(j, sum(k, k))")
-"""
+        """
+        <>tmp[i] = sum(k, 1)
+        out[i] = sum(j, tmp[j])
+        """)
 
+    knl = lp.fix_parameters(knl, n=10)
+    knl = lp.tag_inames(knl, dict(i=i_tag, j=j_tag))
+
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    print(knl)
+
+    evt, (out,) = knl(queue)
 
-def test_scan_unsupported_stride():
+    print(out)
+
+
+def test_scan_not_triangular():
     knl = lp.make_kernel(
         "{[i,j]: 0<=i<100 and 1<=j<=2*i}",
         """
-        a[i] = sum(j, j**2) {id=scan}
+        a[i] = sum(j, j**2)
         """
         )
 
@@ -177,19 +211,11 @@ def test_local_parallel_scan(ctx_factory, n):
     knl = lp.tag_inames(knl, dict(i="l.0"))
     knl = lp.realize_reduction(knl, force_scan=True)
 
-    print(knl)
-
     knl = lp.realize_reduction(knl)
 
     knl = lp.add_dtypes(knl, dict(a=int))
-    c = lp.generate_code_v2(knl)
-
-    print(c.device_code())
 
     evt, (a,) = knl(queue, a=np.arange(16))
-
-    print(a)
-
     assert (a == np.cumsum(np.arange(16)**2)).all()
 
 
@@ -291,8 +317,8 @@ def test_argmax(ctx_factory, i_tag):
     (16, (0, 5)),
     ))
 @pytest.mark.parametrize("iname_tag", ("for", "l.0"))
-def test_segmented_scan(ctx_getter, n, segment_boundaries_indices, iname_tag):
-    ctx = ctx_getter()
+def test_segmented_scan(ctx_factory, n, segment_boundaries_indices, iname_tag):
+    ctx = ctx_factory()
     queue = cl.CommandQueue(ctx)
 
     arr = np.ones(n, dtype=np.float32)
@@ -337,6 +363,37 @@ def test_segmented_scan(ctx_getter, n, segment_boundaries_indices, iname_tag):
     assert [(e == a).all() for e, a in zip(expected, actual)]
 
 
+def test_two_level_scan(ctx_getter):
+    knl = lp.make_kernel(
+        [
+            "{[i,j]: 0 <= i < 256 and 0 <= j <= i}",
+        ],
+        """
+        out[i] = sum(j, j) {id=scan}
+        """,
+        "...")
+
+    #knl = lp.tag_inames(knl, dict(i="l.0"))
+
+    from loopy.transform.reduction import make_two_level_scan
+
+    knl = make_two_level_scan(
+        knl, "scan", inner_length=128,
+        scan_iname="j",
+        sweep_iname="i")
+
+    knl = lp.realize_reduction(knl, force_scan=True)
+
+    print(knl)
+
+    c = ctx_getter()
+    q = cl.CommandQueue(c)
+
+    _, (out,) = knl(q)
+
+    print(out.get())
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab