From cd88b8b134fc8bfac6b61a5881600eaf06922e2d Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Fri, 21 Apr 2017 18:09:31 -0500
Subject: [PATCH] [WIP]

---
 loopy/preprocess.py | 1018 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 947 insertions(+), 71 deletions(-)

diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 17226b63a..808273f02 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -274,7 +274,432 @@ def find_temporary_scope(kernel):
 # {{{ rewrite reduction to imperative form
 
 
-# {{{ reduction utils
+# {{{ utils (not stateful)
+
+from collections import namedtuple
+
+
+_InameClassification = namedtuple("_InameClassifiction",
+                                  "sequential, local_parallel, nonlocal_parallel")
+
+
+def _classify_reduction_inames(kernel, inames):
+    sequential = []
+    local_par = []
+    nonlocal_par = []
+
+    from loopy.kernel.data import (
+            LocalIndexTagBase, UnrolledIlpTag, UnrollTag, VectorizeTag,
+            ParallelTag)
+
+    for iname in inames:
+        iname_tag = kernel.iname_to_tag.get(iname)
+
+        if isinstance(iname_tag, (UnrollTag, UnrolledIlpTag)):
+            # These are nominally parallel, but we can live with
+            # them as sequential.
+            sequential.append(iname)
+
+        elif isinstance(iname_tag, LocalIndexTagBase):
+            local_par.append(iname)
+
+        elif isinstance(iname_tag, (ParallelTag, VectorizeTag)):
+            nonlocal_par.append(iname)
+
+        else:
+            sequential.append(iname)
+
+    return _InameClassification(
+            tuple(sequential), tuple(local_par), tuple(nonlocal_par))
+
+
+def _add_params_to_domain(domain, param_names):
+    dim_type = isl.dim_type
+    nparams_orig = domain.dim(dim_type.param)
+    domain = domain.add_dims(dim_type.param, len(param_names))
+
+    for param_idx, param_name in enumerate(param_names):
+        domain = domain.set_dim_name(
+                dim_type.param, param_idx + nparams_orig, param_name)
+
+    return domain
+
+
+def _move_set_to_param_dims_except(domain, except_dims):
+    dim_type = isl.dim_type
+
+    iname_idx = 0
+    for iname in domain.get_var_names(dim_type.set):
+        if iname not in except_dims:
+            domain = domain.move_dims(
+                    dim_type.param, 0,
+                    dim_type.set, iname_idx, 1)
+            iname_idx -= 1
+        iname_idx += 1
+
+    return domain
+
+
+def _domain_depends_on_given_set_dims(domain, set_dim_names):
+    set_dim_names = frozenset(set_dim_names)
+
+    return any(
+            set_dim_names & set(constr.get_coefficients_by_name())
+            for constr in domain.get_constraints())
+
+
+def _check_reduction_is_triangular(kernel, expr, scan_param):
+    """Check whether the reduction within `expr` with scan parameters described by
+    the structure `scan_param` is triangular. This attempts to verify that the
+    domain for the scan and sweep inames is as follows:
+
+    [params] -> {
+        [other inames..., scan_iname, sweep_iname]:
+            (sweep_min_value
+                <= sweep_iname
+                <= sweep_max_value)
+            and
+            (scan_min_value
+                <= scan_iname
+                <= stride * (sweep_iname - sweep_min_value) + scan_min_value)
+            and
+            (irrelevant constraints)
+    }
+    """
+
+    orig_domain = kernel.get_inames_domain(
+            frozenset((scan_param.sweep_iname, scan_param.scan_iname)))
+
+    sweep_iname = scan_param.sweep_iname
+    scan_iname = scan_param.scan_iname
+    affs = isl.affs_from_space(orig_domain.space)
+
+    sweep_lower_bound = isl.align_spaces(
+            scan_param.sweep_lower_bound,
+            affs[0],
+            across_dim_types=True)
+
+    sweep_upper_bound = isl.align_spaces(
+            scan_param.sweep_upper_bound,
+            affs[0],
+            across_dim_types=True)
+
+    scan_lower_bound = isl.align_spaces(
+            scan_param.scan_lower_bound,
+            affs[0],
+            across_dim_types=True)
+
+    from itertools import product
+
+    for (sweep_lb_domain, sweep_lb_aff), \
+        (sweep_ub_domain, sweep_ub_aff), \
+        (scan_lb_domain, scan_lb_aff) in \
+            product(sweep_lower_bound.get_pieces(),
+                    sweep_upper_bound.get_pieces(),
+                    scan_lower_bound.get_pieces()):
+
+        # Assumptions inherited from the domains of the pwaffs
+        assumptions = sweep_lb_domain & sweep_ub_domain & scan_lb_domain
+
+        # Sweep iname constraints
+        hyp_domain = affs[sweep_iname].ge_set(sweep_lb_aff)
+        hyp_domain &= affs[sweep_iname].le_set(sweep_ub_aff)
+
+        # Scan iname constraints
+        hyp_domain &= affs[scan_iname].ge_set(scan_lb_aff)
+        hyp_domain &= affs[scan_iname].le_set(
+                scan_param.stride * (affs[sweep_iname] - sweep_lb_aff)
+                + scan_lb_aff)
+
+        hyp_domain, = (hyp_domain & assumptions).get_basic_sets()
+        test_domain, = (orig_domain & assumptions).get_basic_sets()
+
+        hyp_gist_against_test = hyp_domain.gist(test_domain)
+        if _domain_depends_on_given_set_dims(hyp_gist_against_test,
+                (sweep_iname, scan_iname)):
+            return False, (
+                    "gist of hypothesis against test domain "
+                    "has sweep or scan dependent constraints: '%s'"
+                    % hyp_gist_against_test)
+
+        test_gist_against_hyp = test_domain.gist(hyp_domain)
+        if _domain_depends_on_given_set_dims(test_gist_against_hyp,
+                (sweep_iname, scan_iname)):
+            return False, (
+                   "gist of test against hypothesis domain "
+                   "has sweep or scan dependent constraint: '%s'"
+                   % test_gist_against_hyp)
+
+    return True, "ok"
+
+
+_ScanCandidateParameters = namedtuple(
+        "_ScanCandidateParameters",
+        "sweep_iname, scan_iname, sweep_lower_bound, "
+        "sweep_upper_bound, scan_lower_bound, stride")
+
+
+def _try_infer_scan_candidate_from_expr(
+        kernel, expr, within_inames, sweep_iname=None):
+    """Analyze `expr` and determine if it can be implemented as a scan.
+    """
+    from loopy.symbolic import Reduction
+    assert isinstance(expr, Reduction)
+
+    if len(expr.inames) != 1:
+        raise ValueError(
+                "Multiple inames in reduction: '%s'" % (", ".join(expr.inames),))
+
+    scan_iname, = expr.inames
+
+    from loopy.kernel.tools import DomainChanger
+    dchg = DomainChanger(kernel, (scan_iname,))
+    domain = dchg.get_original_domain()
+
+    if sweep_iname is None:
+        try:
+            sweep_iname = _try_infer_sweep_iname(
+                    domain, scan_iname, kernel.all_inames())
+        except ValueError as v:
+            raise ValueError(
+                    "Couldn't determine a sweep iname for the scan "
+                    "expression '%s': %s" % (expr, v))
+
+    try:
+        sweep_lower_bound, sweep_upper_bound, scan_lower_bound = (
+                _try_infer_scan_and_sweep_bounds(
+                    kernel, scan_iname, sweep_iname, within_inames))
+    except ValueError as v:
+        raise ValueError(
+                "Couldn't determine bounds for the scan with expression '%s' "
+                "(sweep iname: '%s', scan iname: '%s'): %s"
+                % (expr, sweep_iname, scan_iname, v))
+
+    try:
+        stride = _try_infer_scan_stride(
+                kernel, scan_iname, sweep_iname, sweep_lower_bound)
+    except ValueError as v:
+        raise ValueError(
+                "Couldn't determine a scan stride for the scan with expression '%s' "
+                "(sweep iname: '%s', scan iname: '%s'): %s"
+                % (expr, sweep_iname, scan_iname, v))
+
+    return _ScanCandidateParameters(sweep_iname, scan_iname, sweep_lower_bound,
+            sweep_upper_bound, scan_lower_bound, stride)
+
+
+def _try_infer_sweep_iname(domain, scan_iname, candidate_inames):
+    """The sweep iname is the outer iname which guides the scan.
+
+    E.g. for a domain of {[i,j]: 0<=i<n and 0<=j<=i}, i is the sweep iname.
+    """
+    constrs = domain.get_constraints()
+    sweep_iname_candidate = None
+
+    for constr in constrs:
+        candidate_vars = set([
+                var for var in constr.get_var_dict()
+                if var in candidate_inames])
+
+        # Irrelevant constraint - skip
+        if scan_iname not in candidate_vars:
+            continue
+
+        # No additional inames - skip
+        if len(candidate_vars) == 1:
+            continue
+
+        candidate_vars.remove(scan_iname)
+
+        # Depends on more than one iname - error
+        if len(candidate_vars) > 1:
+            raise ValueError(
+                    "More than one sweep iname candidate for scan iname '%s' found "
+                    "(via constraint '%s')" % (scan_iname, constr))
+
+        next_candidate = candidate_vars.pop()
+
+        if sweep_iname_candidate is None:
+            sweep_iname_candidate = next_candidate
+            defining_constraint = constr
+        else:
+            # Check next_candidate consistency
+            if sweep_iname_candidate != next_candidate:
+                raise ValueError(
+                        "More than one sweep iname candidate for scan iname '%s' "
+                        "found (via constraints '%s', '%s')" %
+                        (scan_iname, defining_constraint, constr))
+
+    if sweep_iname_candidate is None:
+        raise ValueError(
+                "Couldn't find any sweep iname candidates for "
+                "scan iname '%s'" % scan_iname)
+
+    return sweep_iname_candidate
+
+
+def _try_infer_scan_and_sweep_bounds(kernel, scan_iname, sweep_iname, within_inames):
+    domain = kernel.get_inames_domain(frozenset((sweep_iname, scan_iname)))
+    domain = _move_set_to_param_dims_except(domain, (sweep_iname, scan_iname))
+
+    var_dict = domain.get_var_dict()
+    sweep_idx = var_dict[sweep_iname][1]
+    scan_idx = var_dict[scan_iname][1]
+
+    domain = domain.project_out_except(
+            within_inames | kernel.non_iname_variable_names(), (isl.dim_type.param,))
+
+    try:
+        with isl.SuppressedWarnings(domain.get_ctx()):
+            sweep_lower_bound = domain.dim_min(sweep_idx)
+            sweep_upper_bound = domain.dim_max(sweep_idx)
+            scan_lower_bound = domain.dim_min(scan_idx)
+    except isl.Error as e:
+        raise ValueError("isl error: %s" % e)
+
+    return (sweep_lower_bound, sweep_upper_bound, scan_lower_bound)
+
+
+def _try_infer_scan_stride(kernel, scan_iname, sweep_iname, sweep_lower_bound):
+    """The stride is the number of steps the scan iname takes per iteration
+    of the sweep iname. This is allowed to be an integer constant.
+
+    E.g. for a domain of {[i,j]: 0<=i<n and 0<=j<=6*i}, the stride is 6.
+    """
+    dim_type = isl.dim_type
+
+    domain = kernel.get_inames_domain(frozenset([sweep_iname, scan_iname]))
+    domain_with_sweep_param = _move_set_to_param_dims_except(domain, (scan_iname,))
+
+    domain_with_sweep_param = domain_with_sweep_param.project_out_except(
+            (sweep_iname, scan_iname), (dim_type.set, dim_type.param))
+
+    scan_iname_idx = domain_with_sweep_param.find_dim_by_name(
+            dim_type.set, scan_iname)
+
+    # Should be equal to k * sweep_iname, where k is the stride.
+
+    try:
+        with isl.SuppressedWarnings(domain_with_sweep_param.get_ctx()):
+            scan_iname_range = (
+                    domain_with_sweep_param.dim_max(scan_iname_idx)
+                    - domain_with_sweep_param.dim_min(scan_iname_idx)
+                    ).gist(domain_with_sweep_param.params())
+    except isl.Error as e:
+        raise ValueError("isl error: '%s'" % e)
+
+    scan_iname_pieces = scan_iname_range.get_pieces()
+
+    if len(scan_iname_pieces) > 1:
+        raise ValueError("range in multiple pieces: %s" % scan_iname_range)
+    elif len(scan_iname_pieces) == 0:
+        raise ValueError("empty range found for iname '%s'" % scan_iname)
+
+    scan_iname_constr, scan_iname_aff = scan_iname_pieces[0]
+
+    if not scan_iname_constr.plain_is_universe():
+        raise ValueError("found constraints: %s" % scan_iname_constr)
+
+    if scan_iname_aff.dim(dim_type.div):
+        raise ValueError("aff has div: %s" % scan_iname_aff)
+
+    coeffs = scan_iname_aff.get_coefficients_by_name(dim_type.param)
+
+    if len(coeffs) == 0:
+        try:
+            scan_iname_aff.get_constant_val()
+        except:
+            raise ValueError("range for aff isn't constant: '%s'" % scan_iname_aff)
+
+        # If this point is reached we're assuming the domain is of the form
+        # {[i,j]: i=0 and j=0}, so the stride is technically 1 - any value
+        # this function returns will be verified later by
+        # _check_reduction_is_triangular().
+        return 1
+
+    if sweep_iname not in coeffs:
+        raise ValueError("didn't find sweep iname in coeffs: %s" % sweep_iname)
+
+    stride = coeffs[sweep_iname]
+
+    if not stride.is_int():
+        raise ValueError("stride not an integer: %s" % stride)
+
+    if not stride.is_pos():
+        raise ValueError("stride not positive: %s" % stride)
+
+    return stride.to_python()
+
+
+def _get_domain_with_iname_as_param(domain, iname):
+    dim_type = isl.dim_type
+
+    if domain.find_dim_by_name(dim_type.param, iname) >= 0:
+        return domain
+
+    iname_idx = domain.find_dim_by_name(dim_type.set, iname)
+
+    assert iname_idx >= 0, (iname, domain)
+
+    return domain.move_dims(
+        dim_type.param, domain.dim(dim_type.param),
+        dim_type.set, iname_idx, 1)
+
+
+def _create_domain_for_sweep_tracking(orig_domain,
+        tracking_iname, sweep_iname, sweep_min_value, scan_min_value, stride):
+    dim_type = isl.dim_type
+
+    subd = isl.BasicSet.universe(orig_domain.params().space)
+
+    # Add tracking_iname and sweep iname.
+
+    subd = _add_params_to_domain(subd, (sweep_iname, tracking_iname))
+
+    # Here we realize the domain:
+    #
+    # [..., i] -> {
+    #  [j]: 0 <= j - l
+    #       and
+    #       j - l <= k * (i - m)
+    #       and
+    #       k * (i - m - 1) < j - l }
+    # where
+    #   * i is the sweep iname
+    #   * j is the tracking iname
+    #   * k is the stride for the scan
+    #   * l is the lower bound for the scan
+    #   * m is the lower bound for the sweep iname
+    #
+    affs = isl.affs_from_space(subd.space)
+
+    subd &= (affs[tracking_iname] - scan_min_value).ge_set(affs[0])
+    subd &= (affs[tracking_iname] - scan_min_value)\
+            .le_set(stride * (affs[sweep_iname] - sweep_min_value))
+    subd &= (affs[tracking_iname] - scan_min_value)\
+            .gt_set(stride * (affs[sweep_iname] - sweep_min_value - 1))
+
+    # Move tracking_iname into a set dim (NOT sweep iname).
+    subd = subd.move_dims(
+            dim_type.set, 0,
+            dim_type.param, subd.dim(dim_type.param) - 1, 1)
+
+    # Simplify (maybe).
+    orig_domain_with_sweep_param = (
+            _get_domain_with_iname_as_param(orig_domain, sweep_iname))
+    subd = subd.gist_params(orig_domain_with_sweep_param.params())
+
+    subd, = subd.get_basic_sets()
+
+    return subd
+
+
+def _strip_if_scalar(reference_exprs, expr):
+    if len(reference_exprs) == 1:
+        return expr[0]
+    else:
+        return expr
+
 
 def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel):
     """
@@ -455,10 +880,22 @@ def _hackily_ensure_multi_assignment_return_values_are_scoped_private(kernel):
     return kernel.copy(temporary_variables=new_temporary_variables,
                        instructions=new_instructions)
 
+
+def _insert_subdomain_into_domain_tree(kernel, domains, subdomain):
+    # Intersect with inames, because we could have captured some kernel params
+    # in here too...
+    dependent_inames = (
+            frozenset(subdomain.get_var_names(isl.dim_type.param))
+            & kernel.all_inames())
+    idx, = kernel.get_leaf_domain_indices(dependent_inames)
+    domains.insert(idx + 1, subdomain)
+
 # }}}
 
 
-def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
+def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
+                      automagic_scans_ok=True, force_scan=False,
+                      force_outer_iname_for_scan=None):
     """Rewrites reductions into their imperative form. With *insn_id_filter*
     specified, operate only on the instruction with an instruction id matching
     *insn_id_filter*.
@@ -469,6 +906,17 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
     If *insn_id_filter* is not given, all reductions in all instructions will
     be realized.
+
+    If *automagic_scans_ok*, this function will attempt to rewrite triangular
+    reductions as scans automatically.
+
+    If *force_scan* is *True*, this function will attempt to rewrite *all*
+    candidate reductions as scans and raise an error if this is not possible
+    (this is most useful combined with *insn_id_filter*).
+
+    If *force_outer_iname_for_scan* is not *None*, this function will attempt
+    to realize candidate reductions as scans using the specified iname as the
+    outer (sweep) iname.
     """
 
     logger.debug("%s: realize reduction" % kernel.name)
@@ -530,20 +978,23 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
             reduction_dtypes):
         outer_insn_inames = temp_kernel.insn_inames(insn)
 
-        from pymbolic import var
-        acc_var_names = [
-                var_name_gen("acc_"+"_".join(expr.inames))
-                for i in range(nresults)]
-        acc_vars = tuple(var(n) for n in acc_var_names)
+        from loopy.kernel.data import temp_var_scope
+        acc_var_names = make_temporaries(
+                name_based_on="acc_"+"_".join(expr.inames),
+                nvars=nresults,
+                shape=(),
+                dtypes=reduction_dtypes,
+                scope=temp_var_scope.PRIVATE)
 
-        from loopy.kernel.data import TemporaryVariable, temp_var_scope
+        init_insn_depends_on = frozenset()
 
-        for name, dtype in zip(acc_var_names, reduction_dtypes):
-            new_temporary_variables[name] = TemporaryVariable(
-                    name=name,
-                    shape=(),
-                    dtype=dtype,
-                    scope=temp_var_scope.PRIVATE)
+        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+
+        if global_barrier is not None:
+            init_insn_depends_on |= frozenset([global_barrier])
+
+        from pymbolic import var
+        acc_vars = tuple(var(n) for n in acc_var_names)
 
         init_id = insn_id_gen(
                 "%s_%s_init" % (insn.id, "_".join(expr.inames)))
@@ -553,7 +1004,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 assignees=acc_vars,
                 within_inames=outer_insn_inames - frozenset(expr.inames),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset(),
+                depends_on=init_insn_depends_on,
                 expression=expr.operation.neutral_element(*arg_dtypes))
 
         generated_insns.append(init_insn)
@@ -626,6 +1077,14 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 v[iname].lt_set(v[0] + size)).get_basic_sets()
         return bs
 
+    def _make_slab_set_from_range(iname, lbound, ubound):
+        v = isl.make_zero_and_vars([iname])
+        bs, = (
+                v[iname].ge_set(v[0] + lbound)
+                &
+                v[iname].lt_set(v[0] + ubound)).get_basic_sets()
+        return bs
+
     def map_reduction_local(expr, rec, nresults, arg_dtypes,
             reduction_dtypes):
         red_iname, = expr.inames
@@ -650,6 +1109,24 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 _get_int_iname_size(oiname)
                 for oiname in outer_local_inames)
 
+        from loopy.kernel.data import temp_var_scope
+
+        neutral_var_names = make_temporaries(
+                name_based_on="neutral_"+red_iname,
+                nvars=nresults,
+                shape=(),
+                dtypes=reduction_dtypes,
+                scope=temp_var_scope.PRIVATE)
+
+        acc_var_names = make_temporaries(
+                name_based_on="acc_"+red_iname,
+                nvars=nresults,
+                shape=outer_local_iname_sizes + (size,),
+                dtypes=reduction_dtypes,
+                scope=temp_var_scope.LOCAL)
+
+        acc_vars = tuple(var(n) for n in acc_var_names)
+
         # {{{ add separate iname to carry out the reduction
 
         # Doing this sheds any odd conditionals that may be active
@@ -661,32 +1138,9 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         # }}}
 
-        neutral_var_names = [
-                var_name_gen("neutral_"+red_iname)
-                for i in range(nresults)]
-        acc_var_names = [
-                var_name_gen("acc_"+red_iname)
-                for i in range(nresults)]
-        acc_vars = tuple(var(n) for n in acc_var_names)
-
-        from loopy.kernel.data import TemporaryVariable, temp_var_scope
-        for name, dtype in zip(acc_var_names, reduction_dtypes):
-            new_temporary_variables[name] = TemporaryVariable(
-                    name=name,
-                    shape=outer_local_iname_sizes + (size,),
-                    dtype=dtype,
-                    scope=temp_var_scope.LOCAL)
-        for name, dtype in zip(neutral_var_names, reduction_dtypes):
-            new_temporary_variables[name] = TemporaryVariable(
-                    name=name,
-                    shape=(),
-                    dtype=dtype,
-                    scope=temp_var_scope.PRIVATE)
-
         base_iname_deps = outer_insn_inames - frozenset(expr.inames)
 
         neutral = expr.operation.neutral_element(*arg_dtypes)
-
         init_id = insn_id_gen("%s_%s_init" % (insn.id, red_iname))
         init_insn = make_assignment(
                 id=init_id,
@@ -738,16 +1192,15 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                 expression=expr.operation(
                     arg_dtypes,
                     _strip_if_scalar(
-                        neutral_var_names,
+                        expr.exprs,
                         tuple(var(nvn) for nvn in neutral_var_names)),
-                    reduction_expr),
+                    _strip_if_scalar(expr.exprs, expr.exprs)),
                 within_inames=(
                     (outer_insn_inames - frozenset(expr.inames))
                     | frozenset([red_iname])),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset(transfer_depends_on) | insn.depends_on,
-                no_sync_with=frozenset(
-                    [(insn_id, "any") for insn_id in transfer_depends_on]))
+                depends_on=frozenset([init_id, init_neutral_id]) | insn.depends_on,
+                no_sync_with=frozenset([(init_id, "any")]))
         generated_insns.append(transfer_insn)
 
         cur_size = 1
@@ -759,6 +1212,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         istage = 0
         while cur_size > 1:
+
             new_size = cur_size // 2
             assert new_size * 2 == cur_size
 
@@ -798,7 +1252,7 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
         new_insn_add_depends_on.add(prev_id)
         new_insn_add_no_sync_with.add((prev_id, "any"))
-        new_insn_add_within_inames.add(base_exec_iname or stage_exec_iname)
+        new_insn_add_within_inames.add(stage_exec_iname or base_exec_iname)
 
         if nresults == 1:
             assert len(acc_vars) == 1
@@ -807,6 +1261,351 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
             return [acc_var[outer_local_iname_vars + (0,)] for acc_var in acc_vars]
     # }}}
 
+    # {{{ utils (stateful)
+
+    @memoize
+    def get_or_add_sweep_tracking_iname_and_domain(
+            scan_iname, sweep_iname, sweep_min_value, scan_min_value, stride,
+            tracking_iname):
+        domain = temp_kernel.get_inames_domain(frozenset((scan_iname, sweep_iname)))
+
+        inames_added_for_scan.add(tracking_iname)
+
+        new_domain = _create_domain_for_sweep_tracking(domain,
+                tracking_iname, sweep_iname, sweep_min_value, scan_min_value, stride)
+
+        _insert_subdomain_into_domain_tree(temp_kernel, domains, new_domain)
+
+        return tracking_iname
+
+    def replace_var_within_expr(expr, from_var, to_var):
+        from pymbolic.mapper.substitutor import make_subst_func
+
+        from loopy.symbolic import (
+            SubstitutionRuleMappingContext, RuleAwareSubstitutionMapper)
+
+        rule_mapping_context = SubstitutionRuleMappingContext(
+            temp_kernel.substitutions, var_name_gen)
+
+        from pymbolic import var
+        mapper = RuleAwareSubstitutionMapper(
+            rule_mapping_context,
+            make_subst_func({from_var: var(to_var)}),
+            within=lambda *args: True)
+
+        return mapper(expr, temp_kernel, None)
+
+    def make_temporaries(name_based_on, nvars, shape, dtypes, scope):
+        var_names = [
+                var_name_gen(name_based_on.format(index=i))
+                for i in range(nvars)]
+
+        from loopy.kernel.data import TemporaryVariable
+
+        for name, dtype in zip(var_names, dtypes):
+            new_temporary_variables[name] = TemporaryVariable(
+                    name=name,
+                    shape=shape,
+                    dtype=dtype,
+                    scope=scope)
+
+        return var_names
+
+    # }}}
+
+    # {{{ sequential scan
+
+    def map_scan_seq(expr, rec, nresults, arg_dtypes,
+            reduction_dtypes, sweep_iname, scan_iname, sweep_min_value,
+            scan_min_value, stride):
+        outer_insn_inames = temp_kernel.insn_inames(insn)
+        inames_to_remove.add(scan_iname)
+
+        track_iname = var_name_gen(
+                "{sweep_iname}__seq_scan"
+                .format(scan_iname=scan_iname, sweep_iname=sweep_iname))
+
+        get_or_add_sweep_tracking_iname_and_domain(
+                scan_iname, sweep_iname, sweep_min_value, scan_min_value,
+                stride, track_iname)
+
+        from loopy.kernel.data import temp_var_scope
+        acc_var_names = make_temporaries(
+                name_based_on="acc_" + scan_iname,
+                nvars=nresults,
+                shape=(),
+                dtypes=reduction_dtypes,
+                scope=temp_var_scope.PRIVATE)
+
+        from pymbolic import var
+        acc_vars = tuple(var(n) for n in acc_var_names)
+
+        init_id = insn_id_gen(
+                "%s_%s_init" % (insn.id, "_".join(expr.inames)))
+
+        init_insn_depends_on = frozenset()
+
+        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+
+        if global_barrier is not None:
+            init_insn_depends_on |= frozenset([global_barrier])
+
+        init_insn = make_assignment(
+                id=init_id,
+                assignees=acc_vars,
+                within_inames=outer_insn_inames - frozenset(
+                    (sweep_iname,) + expr.inames),
+                within_inames_is_final=insn.within_inames_is_final,
+                depends_on=init_insn_depends_on,
+                expression=expr.operation.neutral_element(*arg_dtypes))
+
+        generated_insns.append(init_insn)
+
+        updated_inner_exprs = tuple(
+                replace_var_within_expr(sub_expr, scan_iname, track_iname)
+                for sub_expr in expr.exprs)
+
+        update_id = insn_id_gen(
+                based_on="%s_%s_update" % (insn.id, "_".join(expr.inames)))
+
+        update_insn_iname_deps = temp_kernel.insn_inames(insn) | set([track_iname])
+        if insn.within_inames_is_final:
+            update_insn_iname_deps = insn.within_inames | set([track_iname])
+
+        scan_insn = make_assignment(
+                id=update_id,
+                assignees=acc_vars,
+                expression=expr.operation(
+                    arg_dtypes,
+                    _strip_if_scalar(acc_vars, acc_vars),
+                    _strip_if_scalar(acc_vars, updated_inner_exprs)),
+                depends_on=frozenset([init_insn.id]) | insn.depends_on,
+                within_inames=update_insn_iname_deps,
+                no_sync_with=insn.no_sync_with,
+                within_inames_is_final=insn.within_inames_is_final)
+
+        generated_insns.append(scan_insn)
+
+        new_insn_add_depends_on.add(scan_insn.id)
+
+        if nresults == 1:
+            assert len(acc_vars) == 1
+            return acc_vars[0]
+        else:
+            return acc_vars
+
+    # }}}
+
+    # {{{ local-parallel scan
+
+    def map_scan_local(expr, rec, nresults, arg_dtypes,
+            reduction_dtypes, sweep_iname, scan_iname,
+            sweep_min_value, scan_min_value, stride):
+
+        scan_size = _get_int_iname_size(sweep_iname)
+
+        assert scan_size > 0
+
+        if scan_size == 1:
+            return map_reduction_seq(
+                    expr, rec, nresults, arg_dtypes, reduction_dtypes)
+
+        outer_insn_inames = temp_kernel.insn_inames(insn)
+
+        from loopy.kernel.data import LocalIndexTagBase
+        outer_local_inames = tuple(
+                oiname
+                for oiname in outer_insn_inames
+                if isinstance(
+                    kernel.iname_to_tag.get(oiname),
+                    LocalIndexTagBase)
+                and oiname != sweep_iname)
+
+        from pymbolic import var
+        outer_local_iname_vars = tuple(
+                var(oiname) for oiname in outer_local_inames)
+
+        outer_local_iname_sizes = tuple(
+                _get_int_iname_size(oiname)
+                for oiname in outer_local_inames)
+
+        track_iname = var_name_gen(
+                "{sweep_iname}__pre_scan"
+                .format(scan_iname=scan_iname, sweep_iname=sweep_iname))
+
+        get_or_add_sweep_tracking_iname_and_domain(
+                scan_iname, sweep_iname, sweep_min_value, scan_min_value, stride,
+                track_iname)
+
+        # {{{ add separate iname to carry out the scan
+
+        # Doing this sheds any odd conditionals that may be active
+        # on our scan_iname.
+
+        base_exec_iname = var_name_gen(sweep_iname + "__scan")
+        domains.append(_make_slab_set(base_exec_iname, scan_size))
+        new_iname_tags[base_exec_iname] = kernel.iname_to_tag[sweep_iname]
+
+        # }}}
+
+        from loopy.kernel.data import temp_var_scope
+
+        read_var_names = make_temporaries(
+                name_based_on="read_"+scan_iname+"_arg_{index}",
+                nvars=nresults,
+                shape=(),
+                dtypes=reduction_dtypes,
+                scope=temp_var_scope.PRIVATE)
+
+        acc_var_names = make_temporaries(
+                name_based_on="acc_"+scan_iname,
+                nvars=nresults,
+                shape=outer_local_iname_sizes + (scan_size,),
+                dtypes=reduction_dtypes,
+                scope=temp_var_scope.LOCAL)
+
+        acc_vars = tuple(var(n) for n in acc_var_names)
+        read_vars = tuple(var(n) for n in read_var_names)
+
+        base_iname_deps = (outer_insn_inames
+                - frozenset(expr.inames) - frozenset([sweep_iname]))
+
+        neutral = expr.operation.neutral_element(*arg_dtypes)
+
+        init_insn_depends_on = insn.depends_on
+
+        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+
+        if global_barrier is not None:
+            init_insn_depends_on |= frozenset([global_barrier])
+
+        init_id = insn_id_gen("%s_%s_init" % (insn.id, scan_iname))
+        init_insn = make_assignment(
+                id=init_id,
+                assignees=tuple(
+                    acc_var[outer_local_iname_vars + (var(base_exec_iname),)]
+                    for acc_var in acc_vars),
+                expression=neutral,
+                within_inames=base_iname_deps | frozenset([base_exec_iname]),
+                within_inames_is_final=insn.within_inames_is_final,
+                depends_on=init_insn_depends_on)
+        generated_insns.append(init_insn)
+
+        updated_inner_exprs = tuple(
+                replace_var_within_expr(sub_expr, scan_iname, track_iname)
+                for sub_expr in expr.exprs)
+
+        from loopy.symbolic import Reduction
+
+        from loopy.symbolic import pw_aff_to_expr
+        sweep_min_value_expr = pw_aff_to_expr(sweep_min_value)
+
+        transfer_id = insn_id_gen("%s_%s_transfer" % (insn.id, scan_iname))
+        transfer_insn = make_assignment(
+                id=transfer_id,
+                assignees=tuple(
+                    acc_var[outer_local_iname_vars
+                            + (var(sweep_iname) - sweep_min_value_expr,)]
+                    for acc_var in acc_vars),
+                expression=Reduction(
+                    operation=expr.operation,
+                    inames=(track_iname,),
+                    exprs=updated_inner_exprs,
+                    allow_simultaneous=False,
+                    ),
+                within_inames=outer_insn_inames - frozenset(expr.inames),
+                within_inames_is_final=insn.within_inames_is_final,
+                depends_on=frozenset([init_id]) | insn.depends_on,
+                no_sync_with=frozenset([(init_id, "any")]) | insn.no_sync_with)
+
+        generated_insns.append(transfer_insn)
+
+        def _strip_if_scalar(c):
+            if len(acc_vars) == 1:
+                return c[0]
+            else:
+                return c
+
+        prev_id = transfer_id
+
+        istage = 0
+        cur_size = 1
+
+        while cur_size < scan_size:
+            stage_exec_iname = var_name_gen("%s__scan_s%d" % (sweep_iname, istage))
+            domains.append(
+                    _make_slab_set_from_range(stage_exec_iname, cur_size, scan_size))
+            new_iname_tags[stage_exec_iname] = kernel.iname_to_tag[sweep_iname]
+
+            for read_var, acc_var in zip(read_vars, acc_vars):
+                read_stage_id = insn_id_gen(
+                        "scan_%s_read_stage_%d" % (scan_iname, istage))
+
+                read_stage_insn = make_assignment(
+                        id=read_stage_id,
+                        assignees=(read_var,),
+                        expression=(
+                                acc_var[
+                                    outer_local_iname_vars
+                                    + (var(stage_exec_iname) - cur_size,)]),
+                        within_inames=(
+                            base_iname_deps | frozenset([stage_exec_iname])),
+                        within_inames_is_final=insn.within_inames_is_final,
+                        depends_on=frozenset([prev_id]))
+
+                if cur_size == 1:
+                    # Performance hack: don't add a barrier here with transfer_insn.
+                    # NOTE: This won't work if the way that local inames
+                    # are lowered changes.
+                    read_stage_insn = read_stage_insn.copy(
+                            no_sync_with=(
+                                read_stage_insn.no_sync_with
+                                | frozenset([(transfer_id, "any")])))
+
+                generated_insns.append(read_stage_insn)
+                prev_id = read_stage_id
+
+            write_stage_id = insn_id_gen(
+                    "scan_%s_write_stage_%d" % (scan_iname, istage))
+            write_stage_insn = make_assignment(
+                    id=write_stage_id,
+                    assignees=tuple(
+                        acc_var[outer_local_iname_vars + (var(stage_exec_iname),)]
+                        for acc_var in acc_vars),
+                    expression=expr.operation(
+                        arg_dtypes,
+                        _strip_if_scalar(read_vars),
+                        _strip_if_scalar(tuple(
+                            acc_var[
+                                outer_local_iname_vars + (var(stage_exec_iname),)]
+                            for acc_var in acc_vars))
+                        ),
+                    within_inames=(
+                        base_iname_deps | frozenset([stage_exec_iname])),
+                    within_inames_is_final=insn.within_inames_is_final,
+                    depends_on=frozenset([prev_id]),
+                    )
+
+            generated_insns.append(write_stage_insn)
+            prev_id = write_stage_id
+
+            cur_size *= 2
+            istage += 1
+
+        new_insn_add_depends_on.add(prev_id)
+        new_insn_add_within_inames.add(sweep_iname)
+
+        output_idx = var(sweep_iname) - sweep_min_value_expr
+
+        if nresults == 1:
+            assert len(acc_vars) == 1
+            return acc_vars[0][outer_local_iname_vars + (output_idx,)]
+        else:
+            return [acc_var[outer_local_iname_vars + (output_idx,)]
+                    for acc_var in acc_vars]
+
+    # }}}
+
     # {{{ seq/par dispatch
 
     def map_reduction(expr, rec, nresults=1):
@@ -825,31 +1624,43 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
             raise LoopyError("reduction used within loop(s) that it was "
                     "supposed to reduce over: " + ", ".join(bad_inames))
 
-        n_sequential = 0
-        n_local_par = 0
+        iname_classes = _classify_reduction_inames(temp_kernel, expr.inames)
 
-        from loopy.kernel.data import (
-                LocalIndexTagBase, UnrolledIlpTag, UnrollTag, VectorizeTag,
-                ParallelTag)
-        for iname in expr.inames:
-            iname_tag = kernel.iname_to_tag.get(iname)
+        n_sequential = len(iname_classes.sequential)
+        n_local_par = len(iname_classes.local_parallel)
+        n_nonlocal_par = len(iname_classes.nonlocal_parallel)
 
-            if isinstance(iname_tag, (UnrollTag, UnrolledIlpTag)):
-                # These are nominally parallel, but we can live with
-                # them as sequential.
-                n_sequential += 1
+        really_force_scan = force_scan and (
+            len(expr.inames) != 1 or expr.inames[0] not in inames_added_for_scan)
 
-            elif isinstance(iname_tag, LocalIndexTagBase):
-                n_local_par += 1
+        def _error_if_force_scan_on(cls, msg):
+            if really_force_scan:
+                raise cls(msg)
 
-            elif isinstance(iname_tag, (ParallelTag, VectorizeTag)):
-                raise LoopyError("the only form of parallelism supported "
-                        "by reductions is 'local'--found iname '%s' "
-                        "tagged '%s'"
-                        % (iname, type(iname_tag).__name__))
+        may_be_implemented_as_scan = False
+        if force_scan or automagic_scans_ok:
+            from loopy.diagnostic import ReductionIsNotTriangularError
+
+            try:
+                # Try to determine scan candidate information (sweep iname, scan
+                # iname, etc).
+                scan_param = _try_infer_scan_candidate_from_expr(
+                        temp_kernel, expr, outer_insn_inames,
+                        sweep_iname=force_outer_iname_for_scan)
+
+            except ValueError as v:
+                error = str(v)
 
             else:
-                n_sequential += 1
+                # Ensures the reduction is triangular (somewhat expensive).
+                may_be_implemented_as_scan, error = (
+                        _check_reduction_is_triangular(
+                            temp_kernel, expr, scan_param))
+
+            if not may_be_implemented_as_scan:
+                _error_if_force_scan_on(ReductionIsNotTriangularError, error)
+
+        # {{{ sanity checks
 
         if n_local_par and n_sequential:
             raise LoopyError("Reduction over '%s' contains both parallel and "
@@ -865,21 +1676,84 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
                     "before code generation."
                     % ", ".join(expr.inames))
 
-        if n_sequential:
-            assert n_local_par == 0
-            return map_reduction_seq(expr, rec, nresults, arg_dtypes,
-                    reduction_dtypes)
-        elif n_local_par:
-            return map_reduction_local(expr, rec, nresults, arg_dtypes,
-                    reduction_dtypes)
-        else:
+        if n_nonlocal_par:
+            bad_inames = iname_classes.nonlocal_parallel
+            raise LoopyError("the only form of parallelism supported "
+                    "by reductions is 'local'--found iname(s) '%s' "
+                    "respectively tagged '%s'"
+                    % (", ".join(bad_inames),
+                       ", ".join(kernel.iname_to_tag[iname]
+                                 for iname in bad_inames)))
+
+        if n_local_par == 0 and n_sequential == 0:
             from loopy.diagnostic import warn_with_kernel
             warn_with_kernel(kernel, "empty_reduction",
                     "Empty reduction found (no inames to reduce over). "
                     "Eliminating.")
 
+            # FIXME: return a neutral element.
+
             return expr.expr
 
+        # }}}
+
+        if may_be_implemented_as_scan:
+            assert force_scan or automagic_scans_ok
+
+            # We require the "scan" iname to be tagged sequential.
+            if n_sequential:
+                sweep_iname = scan_param.sweep_iname
+                sweep_class = _classify_reduction_inames(kernel, (sweep_iname,))
+
+                sequential = sweep_iname in sweep_class.sequential
+                parallel = sweep_iname in sweep_class.local_parallel
+                bad_parallel = sweep_iname in sweep_class.nonlocal_parallel
+
+                if sweep_iname not in outer_insn_inames:
+                    _error_if_force_scan_on(LoopyError,
+                            "Sweep iname '%s' was detected, but is not an iname "
+                            "for the instruction." % sweep_iname)
+                elif bad_parallel:
+                    _error_if_force_scan_on(LoopyError,
+                            "Sweep iname '%s' has an unsupported parallel tag '%s' "
+                            "- the only parallelism allowed is 'local'." %
+                            (sweep_iname, temp_kernel.iname_to_tag[sweep_iname]))
+                elif parallel:
+                    return map_scan_local(
+                            expr, rec, nresults, arg_dtypes, reduction_dtypes,
+                            sweep_iname, scan_param.scan_iname,
+                            scan_param.sweep_lower_bound,
+                            scan_param.scan_lower_bound,
+                            scan_param.stride)
+                elif sequential:
+                    return map_scan_seq(
+                            expr, rec, nresults, arg_dtypes, reduction_dtypes,
+                            sweep_iname, scan_param.scan_iname,
+                            scan_param.sweep_lower_bound,
+                            scan_param.scan_lower_bound,
+                            scan_param.stride)
+
+                # fallthrough to reduction implementation
+
+            else:
+                assert n_local_par > 0
+                scan_iname, = expr.inames
+                _error_if_force_scan_on(LoopyError,
+                        "Scan iname '%s' is parallel tagged: this is not allowed "
+                        "(only the sweep iname should be tagged if parallelism "
+                        "is desired)." % scan_iname)
+
+                # fallthrough to reduction implementation
+
+        if n_sequential:
+            assert n_local_par == 0
+            return map_reduction_seq(
+                    expr, rec, nresults, arg_dtypes, reduction_dtypes)
+        else:
+            assert n_local_par > 0
+            return map_reduction_local(
+                    expr, rec, nresults, arg_dtypes, reduction_dtypes)
+
     # }}}
 
     from loopy.symbolic import ReductionCallbackMapper
@@ -985,6 +1859,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True):
 
     kernel = lp.tag_inames(kernel, new_iname_tags)
 
+    # TODO: remove unused inames...
+
     kernel = (
             _hackily_ensure_multi_assignment_return_values_are_scoped_private(
                 kernel))
-- 
GitLab