From d39c0fb3e19b96aa69d9edf1fc1019e90bcba596 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Wed, 8 Mar 2017 00:24:21 -0600
Subject: [PATCH] [ci skip] Two level scan now can be done in parallel.

---
 loopy/kernel/__init__.py       |  89 +++++++++
 loopy/kernel/tools.py          |   1 +
 loopy/preprocess.py            |  17 +-
 loopy/schedule/tools.py        |  11 --
 loopy/transform/instruction.py |  69 ++++++-
 loopy/transform/reduction.py   | 321 ++++++++++++++-------------------
 loopy/transform/save.py        | 171 +++++++++++-------
 test/test_loopy.py             |  44 +++++
 test/test_scan.py              |  32 +++-
 test/test_transform.py         |   4 +
 10 files changed, 478 insertions(+), 281 deletions(-)

diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 793d31791..dfe9c857c 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -823,6 +823,95 @@ class LoopKernel(ImmutableRecordWithoutPickling):
 
         return result
 
+    @property
+    @memoize_method
+    def global_barrier_order(self):
+        """Return a :class:`tuple` listing the ids of global barrier instructions
+        in the order imposed by the kernel's instruction dependencies.
+
+        See also :class:`loopy.instruction.BarrierInstruction`.
+        """
+        barriers = []
+        visiting = set()
+        visited = set()
+
+        unvisited = set(insn.id for insn in self.instructions)
+
+        while unvisited:
+            stack = [unvisited.pop()]
+
+            while stack:
+                top = stack[-1]
+
+                if top in visiting:  # seen again: its dependencies are done
+                    visiting.remove(top)
+
+                    from loopy.kernel.instruction import BarrierInstruction
+                    insn = self.id_to_insn[top]
+                    if isinstance(insn, BarrierInstruction):
+                        if insn.kind == "global":
+                            barriers.append(top)
+
+                if top in visited:
+                    stack.pop()
+                    continue
+
+                visited.add(top)
+                visiting.add(top)
+
+                for child in self.id_to_insn[top].depends_on:
+                    # A child that is still being visited would be a cycle.
+                    assert child not in visiting
+                    stack.append(child)
+
+        # Ensure the order is unique: each barrier must depend on the previous.
+        for prev_barrier, barrier in zip(barriers, barriers[1:]):
+            if prev_barrier not in self.recursive_insn_dep_map()[barrier]:
+                raise LoopyError(
+                        "Unordered global barriers detected: '%s', '%s'"
+                        % (barrier, prev_barrier))
+
+        return tuple(barriers)
+
+    @memoize_method
+    def find_most_recent_global_barrier(self, insn_id):
+        """Return the id of the latest occurring global barrier which the
+        given instruction (indirectly or directly) depends on, or *None* if this
+        instruction does not depend on a global barrier.
+
+        The return value is guaranteed to be unique because global barriers are
+        totally ordered within the kernel.
+        """
+
+        if len(self.global_barrier_order) == 0:
+            return None
+
+        insn = self.id_to_insn[insn_id]
+
+        if len(insn.depends_on) == 0:
+            return None
+
+        def is_barrier(my_insn_id):
+            insn = self.id_to_insn[my_insn_id]
+            from loopy.kernel.instruction import BarrierInstruction
+            return isinstance(insn, BarrierInstruction) and insn.kind == "global"
+
+        global_barrier_to_ordinal = dict(
+            (b, i) for i, b in enumerate(self.global_barrier_order))
+
+        def get_barrier_ordinal(barrier_id):  # *None* sorts before any barrier
+            return global_barrier_to_ordinal[barrier_id] if barrier_id is not None else -1
+
+        direct_barrier_dependencies = set(
+                dep for dep in insn.depends_on if is_barrier(dep))
+
+        if len(direct_barrier_dependencies) > 0:
+            return max(direct_barrier_dependencies, key=get_barrier_ordinal)
+        else:  # no direct barrier dependency: recurse into each dependency
+            return max((self.find_most_recent_global_barrier(dep)
+                        for dep in insn.depends_on),
+                    key=get_barrier_ordinal)
+
     # }}}
 
     # {{{ argument wrangling
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 539bfbed0..d94136e43 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -1354,4 +1354,5 @@ def draw_dependencies_as_unicode_arrows(
 
 # }}}
 
+
 # vim: foldmethod=marker
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index f139810f1..ef49faa33 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -1242,13 +1242,20 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
         init_id = insn_id_gen(
                 "%s_%s_init" % (insn.id, "_".join(expr.inames)))
 
+        init_insn_depends_on = frozenset()
+
+        global_barrier = temp_kernel.find_most_recent_global_barrier(insn.id)
+
+        if global_barrier is not None:
+            init_insn_depends_on |= frozenset([global_barrier])
+
         init_insn = make_assignment(
                 id=init_id,
                 assignees=acc_vars,
                 within_inames=outer_insn_inames - frozenset(
                     (sweep_iname,) + expr.inames),
                 within_inames_is_final=insn.within_inames_is_final,
-                depends_on=frozenset(),
+                depends_on=init_insn_depends_on,
                 expression=expr.operation.neutral_element(*arg_dtypes))
 
         generated_insns.append(init_insn)
@@ -1257,11 +1264,6 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                 replace_var_within_expr(sub_expr, scan_iname, track_iname)
                 for sub_expr in expr.exprs)
 
-        """
-        updated_inames = tuple(
-                (set(expr.inames) - set([scan_iname])) | set([track_iname]))
-        """
-
         update_id = insn_id_gen(
                 based_on="%s_%s_update" % (insn.id, "_".join(expr.inames)))
 
@@ -1600,9 +1602,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
                     _error_if_force_scan_on(LoopyError,
                             "Sweep iname '%s' has an unsupported parallel tag '%s' "
                             "- the only parallelism allowed is 'local'." %
-                            (sweep_iname, sweep_class.nonlocal_parallel[0]))
+                            (sweep_iname, temp_kernel.iname_to_tag[sweep_iname]))
                 elif parallel:
-                    print(temp_kernel)
                     return map_scan_local(
                             expr, rec, nresults, arg_dtypes, reduction_dtypes,
                             sweep_iname, scan_param.scan_iname,
diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py
index 5de677e72..692e39028 100644
--- a/loopy/schedule/tools.py
+++ b/loopy/schedule/tools.py
@@ -144,17 +144,6 @@ class InstructionQuery(object):
                    if isinstance(self.kernel.iname_to_tag.get(iname),
                                  HardwareParallelTag))
 
-    @memoize_method
-    def common_hw_inames(self, insn_ids):
-        """
-        Return the common set of hardware parallel tagged inames among
-        the list of instructions.
-        """
-        # Get the list of hardware inames in which the temporary is defined.
-        if len(insn_ids) == 0:
-            return set()
-        return set.intersection(*(self.hw_inames(id) for id in insn_ids))
-
 # }}}
 
 
diff --git a/loopy/transform/instruction.py b/loopy/transform/instruction.py
index 7c9c96886..9143052a4 100644
--- a/loopy/transform/instruction.py
+++ b/loopy/transform/instruction.py
@@ -34,7 +34,6 @@ def find_instructions(kernel, insn_match):
     match = parse_match(insn_match)
     return [insn for insn in kernel.instructions if match(kernel, insn)]
 
-
 # }}}
 
 
@@ -207,4 +206,72 @@ def tag_instructions(kernel, new_tag, within=None):
 # }}}
 
 
+# {{{ add nosync
+
+def add_nosync_to_instructions(
+        kernel, scope, source, sink, bidirectional=False):
+    """Add a *nosync* directive between *source* and *sink*.
+
+    *source* and *sink* may be any instruction id match understood by
+    :func:`loopy.match.parse_match`.
+
+    *scope* should be a valid nosync scope.
+
+    If *bidirectional* is True, this adds a nosync to both the source
+    and sink instructions, otherwise the directive is only added to the
+    sink instructions.
+
+    *nosync* attributes are only added if a dependency is present or if
+    the instruction pair is spread across a conflicting group.
+    """
+
+    if isinstance(source, str) and source in kernel.id_to_insn:
+        sources = frozenset([source])
+    else:
+        sources = frozenset(
+                source.id for source in find_instructions(kernel, source))
+
+    if isinstance(sink, str) and sink in kernel.id_to_insn:
+        sinks = frozenset([sink])
+    else:
+        sinks = frozenset(
+                sink.id for sink in find_instructions(kernel, sink))
+
+    def insns_in_conflicting_groups(insn1_id, insn2_id):
+        insn1 = kernel.id_to_insn[insn1_id]
+        insn2 = kernel.id_to_insn[insn2_id]
+        return (
+                bool(insn1.groups & insn2.conflicts_with_groups)
+                or
+                bool(insn2.groups & insn1.conflicts_with_groups))
+
+    from collections import defaultdict
+    nosync_to_add = defaultdict(lambda: set())  # insn id -> {(other id, scope)}
+
+    for sink in sinks:
+        for source in sources:
+
+            needs_nosync = (
+                    source in kernel.recursive_insn_dep_map()[sink]
+                    or insns_in_conflicting_groups(source, sink))
+
+            if not needs_nosync:
+                continue
+
+            nosync_to_add[sink].add((source, scope))
+            if bidirectional:
+                nosync_to_add[source].add((sink, scope))
+
+    new_instructions = list(kernel.instructions)
+
+    for i, insn in enumerate(new_instructions):
+        if insn.id in nosync_to_add:
+            new_instructions[i] = insn.copy(
+                    no_sync_with=insn.no_sync_with | frozenset(nosync_to_add[insn.id]))
+
+    return kernel.copy(instructions=new_instructions)
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/loopy/transform/reduction.py b/loopy/transform/reduction.py
index 1693fb515..2fd086912 100644
--- a/loopy/transform/reduction.py
+++ b/loopy/transform/reduction.py
@@ -40,7 +40,6 @@ __doc__ = """
 
 .. autofunction:: make_two_level_reduction
 .. autofunction:: make_two_level_scan
-.. autofunction:: precompute_scan
 """
 
 
@@ -153,7 +152,6 @@ def _make_slab_set(iname, size):
             v[0].le_set(v[iname])
             &
             v[iname].lt_set(v[0] + size)).get_basic_sets()
-    print("ADDING SLAB", bs)
     return bs
 
 
@@ -197,6 +195,22 @@ def _expand_subst_within_expression(kernel, expr):
     return submap(expr, kernel, insn=None)
 
 
+def _add_global_barrier(kernel, source, sink, barrier_id):
+    from loopy.kernel.instruction import BarrierInstruction
+    barrier_insn = BarrierInstruction(
+            id=barrier_id,
+            depends_on=frozenset([source]),  # the barrier depends on *source*
+            kind="global")
+
+    updated_sink = kernel.id_to_insn[sink]  # *sink* gains a dep on the barrier
+    updated_sink = updated_sink.copy(
+            depends_on=updated_sink.depends_on | frozenset([barrier_id]))
+
+    kernel = _update_instructions(kernel, (barrier_insn, updated_sink), copy=True)
+
+    return kernel
+
+
 def make_two_level_scan(
         kernel, insn_id,
         scan_iname,
@@ -212,8 +226,8 @@ def make_two_level_scan(
         inner_local_tag=None,
         inner_tag=None,
         outer_tag=None,
-        inner_local_iname=None,
-        outer_local_iname=None):
+        inner_iname=None,
+        outer_iname=None):
     """
     Two level scan, mediated through a "local" and "nonlocal" array.
 
@@ -232,6 +246,8 @@ def make_two_level_scan(
 
     # {{{ sanity checks
 
+    # FIXME: More sanity checks...
+
     insn = kernel.id_to_insn[insn_id]
     scan = insn.expression
     assert scan.inames[0] == scan_iname
@@ -241,34 +257,42 @@ def make_two_level_scan(
 
     # {{{ get stable names for everything
 
+    # XXX: add inner_iname and outer_iname to var_name_gen if not none
+
     var_name_gen = kernel.get_var_name_generator()
     insn_id_gen = kernel.get_instruction_id_generator()
 
-    format_kwargs = {"insn": insn_id, "iname": scan_iname, "sweep": sweep_iname}
+    level = 0 #scan_level or try_get_scan_level(sweep_iname)
+
+    format_kwargs = {
+            "insn": insn_id, "iname": scan_iname, "sweep": sweep_iname,
+            "level": level, "next_level": level + 1, "prefix": "l"}
 
     nonlocal_storage_name = var_name_gen(
-            "{insn}_nonlocal".format(**format_kwargs))
+            "{prefix}{level}_insn".format(**format_kwargs))
+
+    if inner_iname is None:
+        inner_iname = var_name_gen(
+                "{prefix}{level}_inner_update_{sweep}".format(**format_kwargs))
+
+    if outer_iname is None:
+        outer_iname = var_name_gen(
+                "{prefix}{level}_outer_update_{sweep}".format(**format_kwargs))
 
-    inner_iname = var_name_gen(
-            "{sweep}_inner".format(**format_kwargs))
-    outer_iname = var_name_gen(
-            "{sweep}_outer".format(**format_kwargs))
     nonlocal_iname = var_name_gen(
-            "{sweep}_nonlocal".format(**format_kwargs))
+            "{prefix}{level}_combine_{sweep}".format(**format_kwargs))
 
-    if inner_local_iname is None:
-        inner_local_iname = var_name_gen(
-                "{sweep}_inner_local".format(**format_kwargs))
+    inner_local_iname = var_name_gen(
+            "{prefix}{next_level}_inner_{sweep}".format(**format_kwargs))
 
     inner_scan_iname = var_name_gen(
-            "{iname}_inner".format(**format_kwargs))
+            "{prefix}{next_level}_{iname}".format(**format_kwargs))
 
     outer_scan_iname = var_name_gen(
-            "{iname}_outer".format(**format_kwargs))
+            "{prefix}{level}_{iname}".format(**format_kwargs))
 
-    if outer_local_iname is None:
-        outer_local_iname = var_name_gen(
-                "{sweep}_outer_local".format(**format_kwargs))
+    outer_local_iname = var_name_gen(
+            "{prefix}{next_level}_outer_{sweep}".format(**format_kwargs))
 
     subst_name = var_name_gen(
             "{insn}_inner_subst".format(**format_kwargs))
@@ -278,11 +302,11 @@ def make_two_level_scan(
 
     if local_storage_name is None:
         local_storage_name = var_name_gen(
-            "{insn}_local".format(**format_kwargs))
+            "{prefix}{next_level}_{insn}".format(**format_kwargs))
 
     if nonlocal_storage_name is None:
         nonlocal_storage_name = var_name_gen(
-            "{insn}_nonlocal".format(**format_kwargs))
+            "{prefix}{level}_{insn}".format(**format_kwargs))
 
     local_scan_insn_id = insn_id_gen(
             "{iname}_local_scan".format(**format_kwargs))
@@ -300,6 +324,27 @@ def make_two_level_scan(
 
     # }}}
 
+    # {{{ utils
+
+    if local_storage_axes is None:
+        local_storage_axes = (outer_iname, inner_iname)
+
+    def pick_out_relevant_axes(full_indices, strip_scalar=False):
+        assert len(full_indices) == 2
+        iname_to_index = dict(zip((outer_iname, inner_iname), full_indices))
+
+        result = []
+        for iname in local_storage_axes:
+            result.append(iname_to_index[iname])
+
+        assert len(result) > 0
+
+        return tuple(result) if not (strip_scalar and len(result) == 1) else result[0]
+
+    # }}}
+
+    # {{{ prepare for two level scan
+
     # Turn the scan into a substitution rule, replace the original scan with a
     # nop and delete the scan iname.
     #
@@ -334,6 +379,8 @@ def make_two_level_scan(
     # Make sure we got rid of everything
     assert scan_iname not in kernel.all_inames()
 
+    # }}}
+
     # {{{ implement local scan
 
     from pymbolic import var
@@ -342,9 +389,16 @@ def make_two_level_scan(
                             var(inner_scan_iname)))
 
     kernel = lp.split_iname(kernel, sweep_iname, inner_length,
-            inner_iname=inner_iname, outer_iname=outer_iname)
+            inner_iname=inner_iname, outer_iname=outer_iname,
+            inner_tag=inner_tag, outer_tag=outer_tag)
 
-    print("SPLITTING INAME, GOT DOMAINS", kernel.domains)
+    kernel = lp.duplicate_inames(kernel,
+            (outer_iname, inner_iname),
+            within="not id:*",
+            new_inames=[outer_local_iname, inner_local_iname],
+            tags={outer_iname: outer_local_tag, inner_iname: inner_local_tag})
+
+    kernel = _add_scan_subdomain(kernel, inner_scan_iname, inner_local_iname)
 
     from loopy.kernel.data import SubstitutionRule
     from loopy.symbolic import Reduction
@@ -353,37 +407,39 @@ def make_two_level_scan(
             name=local_subst_name,
             arguments=(outer_iname, inner_iname),
             expression=Reduction(
-                scan.operation,
-                (inner_scan_iname,),
-                local_scan_expr)
-            )
+                scan.operation, (inner_scan_iname,), local_scan_expr))
 
     substitutions = kernel.substitutions.copy()
     substitutions[local_subst_name] = local_subst
 
     kernel = kernel.copy(substitutions=substitutions)
 
-    print(kernel)
+    all_precompute_inames = (outer_local_iname, inner_local_iname)
+
+    precompute_inames = pick_out_relevant_axes(all_precompute_inames)
+    sweep_inames = pick_out_relevant_axes((outer_iname, inner_iname))
+
+    precompute_outer_inames = (
+            frozenset(all_precompute_inames)
+            - frozenset(precompute_inames))
 
     from pymbolic import var
-    kernel = lp.precompute(
-            kernel,
+    kernel = lp.precompute(kernel,
             [var(local_subst_name)(var(outer_iname), var(inner_iname))],
-            storage_axes=(outer_iname, inner_iname),
-            sweep_inames=(outer_iname, inner_iname),
-            precompute_inames=(outer_local_iname, inner_local_iname),
+            sweep_inames=sweep_inames,
+            precompute_inames=precompute_inames,
+            storage_axes=local_storage_axes,
+            precompute_outer_inames=precompute_outer_inames,
             temporary_name=local_storage_name,
             compute_insn_id=local_scan_insn_id)
 
-    kernel = _add_scan_subdomain(kernel, inner_scan_iname, inner_local_iname)
-
     # }}}
 
     # {{{ implement local to nonlocal information transfer
 
     from loopy.symbolic import pw_aff_to_expr
     nonlocal_storage_len_pw_aff = (
-            # The 2 here is because the first element is 0.
+            # FIXME: should be 1 + len, but the bounds check doesn't like this.
             2 + kernel.get_iname_bounds(outer_iname).upper_bound_pw_aff)
 
     nonlocal_storage_len = pw_aff_to_expr(nonlocal_storage_len_pw_aff)
@@ -396,7 +452,7 @@ def make_two_level_scan(
                 TemporaryVariable(
                     nonlocal_storage_name,
                     shape=(nonlocal_storage_len,),
-                    scope=lp.auto,
+                    scope=nonlocal_storage_scope,
                     base_indices=lp.auto,
                     dtype=lp.auto))
 
@@ -412,7 +468,8 @@ def make_two_level_scan(
             id=nonlocal_init_head_insn_id,
             assignees=(var(nonlocal_storage_name)[0],),
             expression=0,
-            within_inames=frozenset([outer_local_iname]),
+            within_inames=frozenset([outer_local_iname,inner_local_iname]),
+            predicates=frozenset([var(inner_local_iname).eq(0)]),
             depends_on=frozenset([local_scan_insn_id]))
 
     final_element_indices = []
@@ -420,11 +477,17 @@ def make_two_level_scan(
     nonlocal_init_tail = make_assignment(
             id=nonlocal_init_tail_insn_id,
             assignees=(var(nonlocal_storage_name)[var(outer_local_iname) + 1],),
-            expression=var(local_storage_name)[var(outer_local_iname),inner_length - 1],
-            within_inames=frozenset([outer_local_iname]),
-            depends_on=frozenset([local_scan_insn_id]))
-
-    kernel = _update_instructions(kernel, (nonlocal_init_head, nonlocal_init_tail), copy=False)
+            expression=var(local_storage_name)[
+                pick_out_relevant_axes(
+                    (var(outer_local_iname),var(inner_local_iname)),
+                    strip_scalar=True)],
+            no_sync_with=frozenset([(local_scan_insn_id, "local")]),
+            within_inames=frozenset([outer_local_iname,inner_local_iname]),
+            depends_on=frozenset([local_scan_insn_id]),
+            predicates=frozenset([var(inner_local_iname).eq(inner_length - 1)]))
+
+    kernel = _update_instructions(
+            kernel, (nonlocal_init_head, nonlocal_init_tail), copy=False)
 
     # }}}
 
@@ -432,6 +495,9 @@ def make_two_level_scan(
 
     kernel.domains.append(_make_slab_set(nonlocal_iname, nonlocal_storage_len))
 
+    if nonlocal_tag is not None:
+        kernel = lp.tag_inames(kernel, {nonlocal_iname: nonlocal_tag})
+
     kernel = _add_scan_subdomain(kernel, outer_scan_iname, nonlocal_iname)
     
     nonlocal_scan = make_assignment(
@@ -446,165 +512,40 @@ def make_two_level_scan(
 
     kernel = _update_instructions(kernel, (nonlocal_scan,), copy=False)
 
-    # }}}
-
-    # {{{ replace scan with local + nonlocal
-
-    updated_insn = insn.copy(
-        depends_on=insn.depends_on | frozenset([nonlocal_scan_insn_id]),
-        expression=var(nonlocal_storage_name)[var(outer_iname)] + var(local_storage_name)[var(outer_iname), var(inner_iname)])
-
-    kernel = _update_instructions(kernel, (updated_insn,), copy=False)
+    if nonlocal_storage_scope == lp.temp_var_scope.GLOBAL:
+        barrier_id = insn_id_gen("barrier_{insn}".format(**format_kwargs))
+        kernel = _add_global_barrier(kernel,
+                source=nonlocal_init_tail_insn_id,
+                sink=nonlocal_scan_insn_id,
+                barrier_id=barrier_id)
 
     # }}}
 
-    return kernel
-
-
-def precompute_scan(
-        kernel, insn_id,
-        sweep_iname,
-        scan_iname,
-        outer_inames=(),
-        temporary_scope=None,
-        temporary_name=None,
-        replace_insn_with_nop=False):
-    """
-    Turn an expression-based scan into an array-based one.
-
-    This takes a reduction of the form::
-
-        [...,sweep_iname] result = reduce(scan_iname, f(scan_iname))
-
-    and does essentially the following transformation::
-
-        [...,sweep_iname'] temp[sweep_iname'] = f(sweep_iname')
-        [...,sweep_iname] temp[sweep_iname] = reduce(scan_iname, temp[scan_iname])
-        [...,sweep_iname] result = temp[sweep_iname]
-
-    Note: this makes an explicit assumption that the sweep iname shares the
-    same bounds as the scan iname and the bounds start at 0.
-    """
-
-    # {{{ sanity checks
-
-    insn = kernel.id_to_insn[insn_id]
-    scan = insn.expression
-    assert scan.inames[0] == scan_iname
-    assert len(scan.inames) == 1
-
-    # }}}
-
-    # {{{ get a stable name for things
-
-    var_name_gen = kernel.get_var_name_generator()
-    insn_id_gen = kernel.get_instruction_id_generator()
-
-    format_kwargs = {"insn": insn_id, "iname": scan_iname}
-
-    orig_subst_name = var_name_gen(
-            "{iname}_orig_subst".format(**format_kwargs))
-
-    scan_subst_name = var_name_gen(
-            "{iname}_subst".format(**format_kwargs))
-
-    precompute_insn = insn_id_gen(
-            "{insn}_precompute".format(**format_kwargs))
-
-    precompute_reduction_insn = insn_id_gen(
-            "{insn}_precompute_reduce".format(**format_kwargs))
-
-    if temporary_name is None:
-        temporary_name = var_name_gen(
-            "{insn}_precompute".format(**format_kwargs))
-
-    # }}}
-
-    from loopy.transform.data import reduction_arg_to_subst_rule
-    kernel = reduction_arg_to_subst_rule(
-            kernel, scan_iname, subst_rule_name=orig_subst_name)
-
-    # {{{ create our own variant of the substitution rule
-
-    # FIXME: There has to be a better way of this.
-
-    orig_subst = kernel.substitutions[orig_subst_name]
-
-    from pymbolic.mapper.substitutor import make_subst_func
-
-    from loopy.symbolic import (
-        SubstitutionRuleMappingContext, RuleAwareSubstitutionMapper)
+    # {{{ replace scan with local + nonlocal
 
-    rule_mapping_context = SubstitutionRuleMappingContext(
-            kernel.substitutions, var_name_gen)
+    updated_depends_on = insn.depends_on | frozenset([nonlocal_scan_insn_id])
 
-    from pymbolic import var
-    mapper = RuleAwareSubstitutionMapper(
-            rule_mapping_context,
-            make_subst_func({scan_iname: var(sweep_iname)}),
-            within=lambda *args: True)
+    if nonlocal_storage_scope == lp.temp_var_scope.GLOBAL:
+        barrier_id = insn_id_gen("barrier_{insn}".format(**format_kwargs))
+        kernel = (_add_global_barrier(kernel,
+                source=nonlocal_scan_insn_id, sink=insn_id, barrier_id=barrier_id))
+        updated_depends_on |= frozenset([barrier_id])
 
-    scan_subst = orig_subst.copy(
-            name=scan_subst_name,
-            arguments=outer_inames + (sweep_iname,),
-            expression=mapper(orig_subst.expression, kernel, None))
+    nonlocal_part = var(nonlocal_storage_name)[var(outer_iname)]
 
-    substitutions = kernel.substitutions.copy()
+    local_part = var(local_storage_name)[
+            pick_out_relevant_axes(
+                (var(outer_iname), var(inner_iname)), strip_scalar=True)]
 
-    substitutions[scan_subst_name] = scan_subst
+    updated_insn = insn.copy(
+            depends_on=updated_depends_on,
+            # XXX: scan binary op
+            expression=nonlocal_part + local_part)
 
-    kernel = kernel.copy(substitutions=substitutions)
+    kernel = _update_instructions(kernel, (updated_insn,), copy=False)
 
     # }}}
 
-    print(kernel)
-
-    # FIXME: multi assignments
-    from pymbolic import var
-
-    # FIXME: Make a new precompute iname....
-
-    kernel = lp.precompute(kernel,
-            [var(scan_subst_name)(
-                *(tuple(var(o) for o in outer_inames) +
-                  (var(sweep_iname),)))],
-            sweep_inames=outer_inames + (sweep_iname,),
-            precompute_inames=(sweep_iname,),
-            temporary_name=temporary_name,
-            temporary_scope=temporary_scope,
-            # FIXME: why on earth is this needed
-            compute_insn_id=precompute_insn)
-
-    from loopy.kernel.instruction import make_assignment
-
-    from loopy.symbolic import Reduction
-    precompute_reduction = insn.copy(
-            id=precompute_reduction_insn,
-            assignee=var(temporary_name)[var(sweep_iname)],
-            expression=Reduction(
-                operation=scan.operation,
-                inames=(scan_iname,),
-                exprs=(var(temporary_name)[var(scan_iname)],),
-                allow_simultaneous=False,
-                ),
-            depends_on=insn.depends_on | frozenset([precompute_insn]))
-
-    kernel = kernel.copy(instructions=kernel.instructions +
-                         [precompute_reduction])
-
-    new_insn = insn.copy(
-           expression=var(temporary_name)[var(sweep_iname)],
-           depends_on=
-           frozenset([precompute_reduction_insn]) | insn.depends_on)
-
-    instructions = list(kernel.instructions)
-
-    for i, insn in enumerate(instructions):
-        if insn.id == insn_id:
-            instructions[i] = new_insn
-
-    kernel = kernel.copy(instructions=instructions)
-
     return kernel
 
 
diff --git a/loopy/transform/save.py b/loopy/transform/save.py
index 8afc1695a..29f4c0238 100644
--- a/loopy/transform/save.py
+++ b/loopy/transform/save.py
@@ -197,16 +197,16 @@ class TemporarySaver(object):
 
             The original temporary variable object.
 
-        .. attribute:: hw_inames
-
-            The common list of hw axes that define the original object.
-
         .. attribute:: hw_dims
 
             A list of expressions, to be added in front of the shape
             of the promoted temporary value, corresponding to
             hardware dimensions
 
+        .. attribute:: hw_tags
+
+            The tags for the inames associated with hw_dims
+
         .. attribute:: non_hw_dims
 
             A list of expressions, to be added in front of the shape
@@ -241,6 +241,75 @@ class TemporarySaver(object):
         self.updated_temporary_variables = {}
         self.saves_or_reloads_added = {}
 
+    def get_hw_axis_sizes_and_tags_for_save_slot(self, temporary):
+        """
+        Return a pair ``(sizes, tags)`` describing the hardware axes that get
+        prefixed to the shape of the global save slot for *temporary*, as
+        needed for saving and restoring it across kernel calls. *tags* lists
+        group tags followed by local tags, each sorted by axis.
+
+        In the case of local temporaries, inames that are tagged
+        hw-local do not contribute to the global storage shape.
+        """
+        accessor_insn_ids = (
+            self.insn_query.insns_reading_or_writing(temporary.name))
+
+        group_tags = None
+        local_tags = None
+
+        def _sortedtags(tags):
+            return sorted(tags, key=lambda tag: tag.axis)
+
+        for insn_id in accessor_insn_ids:
+            insn = self.kernel.id_to_insn[insn_id]
+
+            my_group_tags = []
+            my_local_tags = []
+
+            for iname in insn.within_inames:
+                tag = self.kernel.iname_to_tag.get(iname)  # None if untagged
+
+                from loopy.kernel.data import (
+                    GroupIndexTag, LocalIndexTag, ParallelTag)
+
+                if isinstance(tag, GroupIndexTag):
+                    my_group_tags.append(tag)
+                elif isinstance(tag, LocalIndexTag):
+                    my_local_tags.append(tag)
+                elif isinstance(tag, ParallelTag):
+                    raise ValueError(
+                        "iname '%s' is tagged with '%s' - only "
+                        "local and global tags are supported for "
+                        "auto saving of temporaries" %
+                        (iname, tag))
+
+            if group_tags is None:
+                group_tags = _sortedtags(my_group_tags)
+                local_tags = _sortedtags(my_local_tags)
+
+            if (
+                    group_tags != _sortedtags(my_group_tags)
+                    or local_tags != _sortedtags(my_local_tags)):
+                raise ValueError(
+                    "inconsistent parallel tags across instructions that access '%s'"
+                    % temporary.name)
+
+        if group_tags is None:
+            assert local_tags is None
+            return (), ()
+
+        group_sizes, local_sizes = (
+            self.kernel.get_grid_sizes_for_insn_ids_as_exprs(accessor_insn_ids))
+
+        if temporary.scope == lp.temp_var_scope.LOCAL:
+            # Elide local axes in the save slot for local temporaries.
+            del local_tags[:]
+            local_sizes = ()
+
+        # We set hw_dims to be arranged according to the order:
+        #    g.0 < g.1 < ... < l.0 < l.1 < ...
+        return (group_sizes + local_sizes), tuple(group_tags + local_tags)
+
     @memoize_method
     def auto_promote_temporary(self, temporary_name):
         temporary = self.kernel.temporary_variables[temporary_name]
@@ -259,48 +328,7 @@ class TemporarySaver(object):
             raise ValueError(
                 "Cannot promote temporaries with base_storage to global")
 
-        # `hw_inames`: The set of hw-parallel tagged inames that this temporary
-        # is associated with. This is used for determining the shape of the
-        # global storage needed for saving and restoring the temporary across
-        # kernel calls.
-        #
-        # TODO: Make a policy decision about which dimensions to use. Currently,
-        # the code looks at each instruction that defines or uses the temporary,
-        # and takes the common set of hw-parallel tagged inames associated with
-        # these instructions.
-        #
-        # Furthermore, in the case of local temporaries, inames that are tagged
-        # hw-local do not contribute to the global storage shape.
-        hw_inames = self.insn_query.common_hw_inames(
-            self.insn_query.insns_reading_or_writing(temporary.name))
-
-        # We want hw_inames to be arranged according to the order:
-        #    g.0 < g.1 < ... < l.0 < l.1 < ...
-        # Sorting lexicographically accomplishes this.
-        hw_inames = sorted(hw_inames,
-            key=lambda iname: str(self.kernel.iname_to_tag[iname]))
-
-        # Calculate the sizes of the dimensions that get added in front for
-        # the global storage of the temporary.
-        hw_dims = []
-
-        backing_hw_inames = []
-
-        for iname in hw_inames:
-            tag = self.kernel.iname_to_tag[iname]
-            from loopy.kernel.data import LocalIndexTag
-            is_local_iname = isinstance(tag, LocalIndexTag)
-            if is_local_iname and temporary.scope == temp_var_scope.LOCAL:
-                # Restrict shape to that of group inames for locals.
-                continue
-            backing_hw_inames.append(iname)
-            from loopy.isl_helpers import static_max_of_pw_aff
-            from loopy.symbolic import aff_to_expr
-            hw_dims.append(
-                aff_to_expr(
-                    static_max_of_pw_aff(
-                        self.kernel.get_iname_bounds(iname).size, False)))
-
+        hw_dims, hw_tags = self.get_hw_axis_sizes_and_tags_for_save_slot(temporary)
         non_hw_dims = temporary.shape
 
         if len(non_hw_dims) == 0 and len(hw_dims) == 0:
@@ -310,9 +338,9 @@ class TemporarySaver(object):
         backing_temporary = self.PromotedTemporary(
             name=self.var_name_gen(temporary.name + "_save_slot"),
             orig_temporary=temporary,
-            hw_dims=tuple(hw_dims),
-            non_hw_dims=non_hw_dims,
-            hw_inames=backing_hw_inames)
+            hw_dims=hw_dims,
+            hw_tags=hw_tags,
+            non_hw_dims=non_hw_dims)
 
         return backing_temporary
 
@@ -330,8 +358,7 @@ class TemporarySaver(object):
         dchg = DomainChanger(
             self.kernel,
             frozenset(
-                self.insn_query.inames_in_subkernel(subkernel) |
-                set(promoted_temporary.hw_inames)))
+                self.insn_query.inames_in_subkernel(subkernel)))
 
         domain, hw_inames, dim_inames, iname_to_tag = \
             self.augment_domain_for_save_or_reload(
@@ -342,7 +369,7 @@ class TemporarySaver(object):
         save_or_load_insn_id = self.insn_name_gen(
             "{name}.{mode}".format(name=temporary, mode=mode))
 
-        def subscript_or_var(agg, subscript=()):
+        def add_subscript_if_nonempty(agg, subscript=()):
             from pymbolic.primitives import Subscript, Variable
             if len(subscript) == 0:
                 return Variable(agg)
@@ -354,10 +381,10 @@ class TemporarySaver(object):
         dim_inames_trunc = dim_inames[:len(promoted_temporary.orig_temporary.shape)]
 
         args = (
-            subscript_or_var(
-                temporary, dim_inames_trunc),
-            subscript_or_var(
-                promoted_temporary.name, hw_inames + dim_inames))
+            add_subscript_if_nonempty(
+                temporary, subscript=dim_inames_trunc),
+            add_subscript_if_nonempty(
+                promoted_temporary.name, subscript=hw_inames + dim_inames))
 
         if mode == "save":
             args = reversed(args)
@@ -471,7 +498,9 @@ class TemporarySaver(object):
 
         # Add dimension-dependent inames.
         dim_inames = []
-        domain = domain.add(isl.dim_type.set, len(promoted_temporary.non_hw_dims))
+        domain = domain.add(isl.dim_type.set,
+                            len(promoted_temporary.non_hw_dims)
+                            + len(promoted_temporary.hw_dims))
 
         for dim_idx, dim_size in enumerate(promoted_temporary.non_hw_dims):
             new_iname = self.insn_name_gen("{name}_{mode}_axis_{dim}_{sk}".
@@ -496,22 +525,30 @@ class TemporarySaver(object):
             from loopy.symbolic import aff_from_expr
             domain &= aff[new_iname].lt_set(aff_from_expr(domain.space, dim_size))
 
-        # FIXME: Use promoted_temporary.hw_inames
-        hw_inames = []
+        dim_offset = orig_dim + len(promoted_temporary.non_hw_dims)
 
-        # Add hardware inames duplicates.
-        for t_idx, hw_iname in enumerate(promoted_temporary.hw_inames):
+        hw_inames = []
+        # Add hardware dims.
+        for hw_iname_idx, (hw_tag, dim) in enumerate(
+                zip(promoted_temporary.hw_tags, promoted_temporary.hw_dims)):
             new_iname = self.insn_name_gen("{name}_{mode}_hw_dim_{dim}_{sk}".
                 format(name=orig_temporary.name,
                        mode=mode,
-                       dim=t_idx,
+                       dim=hw_iname_idx,
                        sk=subkernel))
-            hw_inames.append(new_iname)
-            iname_to_tag[new_iname] = self.kernel.iname_to_tag[hw_iname]
+            domain = domain.set_dim_name(
+                isl.dim_type.set, dim_offset + hw_iname_idx, new_iname)
 
-        from loopy.isl_helpers import duplicate_axes
-        domain = duplicate_axes(
-            domain, promoted_temporary.hw_inames, hw_inames)
+            aff = isl.affs_from_space(domain.space)
+            from loopy.symbolic import aff_from_expr
+            domain = (domain
+                &
+                aff[0].le_set(aff[new_iname])
+                &
+                aff[new_iname].lt_set(aff_from_expr(domain.space, dim)))
+
+            self.updated_iname_to_tag[new_iname] = hw_tag
+            hw_inames.append(new_iname)
 
         # The operations on the domain above return a Set object, but the
         # underlying domain should be expressible as a single BasicSet.
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 5e4d013b3..1d1450fc0 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2117,6 +2117,50 @@ def test_barrier_insertion_near_bottom_of_loop():
     assert_barrier_between(knl, "ainit", "aupdate", ignore_barriers_in_levels=[1])
 
 
+def test_global_barrier_order_finding():
+    knl = lp.make_kernel(
+            "{[i,itrip]: 0<=i<n and 0<=itrip<ntrips}",
+            """
+            for i
+                for itrip
+                    ... gbarrier {id=top}
+                    <> z[i] = z[i+1] + z[i]  {id=wr_z,dep=top}
+                    <> v[i] = 11  {id=wr_v,dep=top}
+                    ... gbarrier {dep=wr_z:wr_v,id=yoink}
+                    z[i] = z[i] - z[i+1] + v[i] {id=iupd, dep=yoink}
+                end
+                ... nop {id=nop}
+                ... gbarrier {dep=iupd,id=postloop}
+                z[i] = z[i] - z[i+1] + v[i]  {id=zzzv,dep=postloop}
+            end
+            """)
+
+    assert knl.global_barrier_order == ("top", "yoink", "postloop")
+
+    for insn, barrier in (
+            ("nop", None),
+            ("top", None),
+            ("wr_z", "top"),
+            ("wr_v", "top"),
+            ("yoink", "top"),
+            ("postloop", "yoink"),
+            ("zzzv", "postloop")):
+        assert knl.find_most_recent_global_barrier(insn) == barrier
+
+
+def test_global_barrier_error_if_unordered():
+    # FIXME: Declaring two unordered global barriers should itself be illegal.
+    knl = lp.make_kernel("{[i]: 0 <= i < 10}",
+            """
+            ... gbarrier
+            ... gbarrier
+            """)
+
+    from loopy.diagnostic import LoopyError
+    with pytest.raises(LoopyError):
+        knl.global_barrier_order
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
diff --git a/test/test_scan.py b/test/test_scan.py
index aabfe3031..ae046818b 100644
--- a/test/test_scan.py
+++ b/test/test_scan.py
@@ -366,10 +366,10 @@ def test_segmented_scan(ctx_factory, n, segment_boundaries_indices, iname_tag):
 def test_two_level_scan(ctx_getter):
     knl = lp.make_kernel(
         [
-            "{[i,j]: 0 <= i < 256 and 0 <= j <= i}",
+            "{[i,j]: 0 <= i < 16 and 0 <= j <= i}",
         ],
         """
-        out[i] = sum(j, j) {id=scan}
+        out[i] = sum(j, 1) {id=insn}
         """,
         "...")
 
@@ -378,12 +378,36 @@ def test_two_level_scan(ctx_getter):
     from loopy.transform.reduction import make_two_level_scan
 
     knl = make_two_level_scan(
-        knl, "scan", inner_length=128,
+        knl, "insn", inner_length=4,
         scan_iname="j",
-        sweep_iname="i")
+        sweep_iname="i",
+        local_storage_axes=(("l0_inner_update_i",)),
+        inner_iname="l0_inner_update_i",
+        inner_tag="l.0",
+        outer_tag="g.0",
+        local_storage_scope=lp.temp_var_scope.PRIVATE,
+        nonlocal_storage_scope=lp.temp_var_scope.GLOBAL,
+        inner_local_tag="l.0",
+        outer_local_tag="g.0")
+
+    print(knl)
 
     knl = lp.realize_reduction(knl, force_scan=True)
 
+    from loopy.transform.instruction import add_nosync_to_instructions
+    knl = add_nosync_to_instructions(
+            knl,
+            scope="global",
+            source="writes:acc_l0_j",
+            sink="reads:acc_l0_j")
+
+    from loopy.transform.save import save_and_reload_temporaries
+
+    knl = lp.preprocess_kernel(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+    knl = save_and_reload_temporaries(knl)
+    knl = lp.get_one_scheduled_kernel(knl)
+
     print(knl)
 
     c = ctx_getter()
diff --git a/test/test_transform.py b/test/test_transform.py
index ac5a26f6a..cf2dac48f 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -402,6 +402,10 @@ def test_precompute_with_preexisting_inames_fail():
                 precompute_inames="ii,jj")
 
 
+def test_add_nosync_to_instructions():
+    knl = lp.make_kernel("")
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab