From ed0c18bb02a1aa6a89df72ff0ed78e4c25830379 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Tue, 17 May 2016 04:14:23 -0500
Subject: [PATCH] Add kernel-splitting temporaries test, move device mapping to schedule

---
 loopy/check.py                   |  27 +++++
 loopy/schedule/__init__.py       |   2 +-
 loopy/schedule/device_mapping.py | 202 +++++++++++++++++--------------
 loopy/target/opencl.py           |   3 +
 test/test_loopy.py               |  41 +++++++
 5 files changed, 186 insertions(+), 89 deletions(-)

diff --git a/loopy/check.py b/loopy/check.py
index 0ef3d27cf..876629fc3 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -342,6 +342,32 @@ def check_write_destinations(kernel):
                     or wvar in kernel.arg_dict) and wvar not in kernel.all_params():
                 raise LoopyError
 
+
+def check_that_temporaries_are_well_defined_in_hw_axes(kernel):
+    """Raise :class:`LoopyError` if an instruction uses a temporary outside
+    the hardware-parallel inames common to all of its defining instructions.
+    """
+    from loopy.schedule.device_mapping import (
+        get_common_hw_inames, get_def_and_use_lists_for_all_temporaries,
+        get_hw_inames)
+
+    def_lists, use_lists = get_def_and_use_lists_for_all_temporaries(kernel)
+
+    for temporary in sorted(def_lists):
+        hw_inames = get_common_hw_inames(kernel, def_lists[temporary])
+        # Entries of use_lists are instruction ids, so *use* is itself the id.
+        for use in use_lists[temporary]:
+            use_hw_inames = get_hw_inames(kernel, use)
+            if not hw_inames <= use_hw_inames:
+                raise LoopyError(
+                    "Temporary variable `{temporary}` gets used in a more "
+                    "general hardware parallel loop than it is defined. "
+                    "(used by instruction id `{id}`, inames: {use_inames}) "
+                    "(defined in inames: {def_inames}).".format(
+                        temporary=temporary, id=use,
+                        use_inames=", ".join(sorted(use_hw_inames)),
+                        def_inames=", ".join(sorted(hw_inames))))
+
 # }}}
 
 
@@ -359,6 +385,7 @@ def pre_schedule_checks(kernel):
         check_for_data_dependent_parallel_bounds(kernel)
         check_bounds(kernel)
         check_write_destinations(kernel)
+        check_that_temporaries_are_well_defined_in_hw_axes(kernel)
 
         logger.info("pre-schedule check %s: done" % kernel.name)
     except KeyboardInterrupt:
diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py
index d62eb04a9..132bf01e3 100644
--- a/loopy/schedule/__init__.py
+++ b/loopy/schedule/__init__.py
@@ -1506,7 +1506,7 @@ def generate_loop_schedules(kernel, debug_args={}):
                         schedule=gen_sched,
                         state=kernel_state.SCHEDULED)
 
-                from loopy.codegen.device_mapping import \
+                from loopy.schedule.device_mapping import \
                         map_schedule_onto_host_or_device
                 new_kernel = map_schedule_onto_host_or_device(new_kernel)
                 yield new_kernel
diff --git a/loopy/schedule/device_mapping.py b/loopy/schedule/device_mapping.py
index 51ac79432..481344d26 100644
--- a/loopy/schedule/device_mapping.py
+++ b/loopy/schedule/device_mapping.py
@@ -22,17 +22,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from pytools import Record
 from loopy.diagnostic import LoopyError
 
 
-def postprocess(kernel, global_barrier_splitting=False):
-    # Analyze the kernel to determine if temporaries are used in a sane way.
-    # TODO: Should probably be done in a pre codegen check.
-    check_temporary_sanity(kernel)
+def map_schedule_onto_host_or_device(kernel):
+    # Split the schedule onto host or device.
+    kernel = map_schedule_onto_host_or_device_impl(kernel)
 
-    if not global_barrier_splitting:
-        from loopy.schedule import CallKernel
+    if not kernel.target.split_kernel_at_global_barriers():
+        from loopy.schedule import CallKernel, ReturnFromKernel
         new_schedule = (
             [CallKernel(kernel_name=kernel.name,
                         extra_inames=[],
@@ -40,29 +38,45 @@ def postprocess(kernel, global_barrier_splitting=False):
             kernel.schedule +
             [ReturnFromKernel(kernel_name=kernel.name)])
         return kernel.copy(schedule=new_schedule)
-    # Split the schedule onto host or device.
-    kernel = map_schedule_onto_host_or_device(kernel)
-    # Compute which temporaries and inames go into which kernel
-    kernel = save_and_restore_temporaries(kernel)
+
+    # Compute which temporaries and inames go into which kernel.
+    kernel = restore_and_save_temporaries(kernel)
     return kernel
 
 
 def get_block_boundaries(schedule):
-    from loopy.schedule import (
-        EnterLoop, LeaveLoop, CallKernel, ReturnFromKernel)
+    """
+    Return a dictionary mapping the schedule index of each BeginBlockItem
+    to the index of its matching EndBlockItem, and vice versa.
+    """
+    from loopy.schedule import (BeginBlockItem, EndBlockItem)
     block_bounds = {}
     active_blocks = []
     for idx, sched_item in enumerate(schedule):
-        if isinstance(sched_item, (EnterLoop, CallKernel)):
+        if isinstance(sched_item, BeginBlockItem):
             active_blocks.append(idx)
-        elif isinstance(sched_item, (LeaveLoop, ReturnFromKernel)):
+        elif isinstance(sched_item, EndBlockItem):
             start = active_blocks.pop()
             block_bounds[start] = idx
             block_bounds[idx] = start
     return block_bounds
 
 
+def get_hw_inames(kernel, insn):
+    """Collect the set of inames of *insn* tagged as hardware parallel."""
+    from loopy.kernel.data import HardwareParallelTag
+
+    def is_hw_parallel(iname):
+        return isinstance(kernel.iname_to_tag.get(iname), HardwareParallelTag)
+
+    return set(filter(is_hw_parallel, kernel.insn_inames(insn)))
+
+
 def get_common_hw_inames(kernel, insn_ids):
+    """
+    Return the common set of hardware parallel tagged inames among
+    the list of instructions.
+    """
     # Get the list of hardware inames in which the temporary is defined.
     if len(insn_ids) == 0:
         return set()
@@ -77,6 +91,9 @@ def get_common_hw_inames(kernel, insn_ids):
 
 
 def filter_out_subscripts(exprs):
+    """
+    Remove subscripts from expressions in `exprs`.
+    """
     result = set()
     from pymbolic.primitives import Subscript
     for expr in exprs:
@@ -87,6 +104,9 @@ def filter_out_subscripts(exprs):
 
 
 def filter_temporaries(kernel, items):
+    """
+    Keep only the values in `items` which are temporaries.
+    """
     from pymbolic.primitives import Subscript, Variable
     result = set()
     for item in items:
@@ -101,6 +121,9 @@ def filter_temporaries(kernel, items):
 
 
 def get_use_set(insn, include_subscripts=True):
+    """
+    Return the use-set of the instruction, for liveness analysis.
+    """
     result = insn.read_dependency_names()
     if not include_subscripts:
         result = filter_out_subscripts(result)
@@ -108,6 +131,9 @@ def get_use_set(insn, include_subscripts=True):
 
 
 def get_def_set(insn, include_subscripts=True):
+    """
+    Return the def-set of the instruction, for liveness analysis.
+    """
     result = insn.write_dependency_names()
     if not include_subscripts:
         result = filter_out_subscripts(result)
@@ -115,11 +141,13 @@ def get_def_set(insn, include_subscripts=True):
 
 
 def get_def_and_use_lists_for_all_temporaries(kernel):
+    """
+    Return a pair `def_lists`, `use_lists` which map temporary variable
+    names to lists of instructions where they are defined or used.
+    """
     def_lists = dict((t, []) for t in kernel.temporary_variables)
     use_lists = dict((t, []) for t in kernel.temporary_variables)
 
-    # {{{ Gather use-def information
-
     for insn in kernel.instructions:
         assignees = get_def_set(insn, include_subscripts=False)
         dependencies = get_use_set(insn, include_subscripts=False)
@@ -138,29 +166,28 @@ def get_def_and_use_lists_for_all_temporaries(kernel):
             if dep in kernel.temporary_variables:
                 use_lists[dep].append(insn.id)
 
-    # }}}
-
     return def_lists, use_lists
 
 # }}}
 
 
 def compute_live_temporaries(kernel, schedule):
+    """
+    Compute live-in and live-out sets for temporary variables.
+    """
     live_in = [set() for i in range(len(schedule) + 1)]
     live_out = [set() for i in range(len(schedule))]
 
     id_to_insn = kernel.id_to_insn
     block_bounds = get_block_boundaries(schedule)
 
-    # {{{ Liveness analysis
+    # {{{ Liveness analysis implementation
 
     from loopy.schedule import (
-        EnterLoop, LeaveLoop, CallKernel, ReturnFromKernel, Barrier, RunInstruction)
-
+        LeaveLoop, ReturnFromKernel, Barrier, RunInstruction)
 
     def compute_subrange_liveness(start_idx, end_idx):
         idx = end_idx
-
         while start_idx <= idx:
             sched_item = schedule[idx]
             if isinstance(sched_item, LeaveLoop):
@@ -203,9 +230,9 @@ def compute_live_temporaries(kernel, schedule):
             elif isinstance(sched_item, Barrier):
                 live_in[idx] = live_out[idx] = live_in[idx + 1]
                 idx -= 1
-
             else:
-                raise ValueError()
+                raise LoopyError("unexepcted type of schedule item: %s"
+                        % type(sched_item).__name__)
 
     # }}}
 
@@ -228,7 +255,11 @@ def compute_live_temporaries(kernel, schedule):
     return live_in, live_out
 
 
-def save_and_restore_temporaries(kernel):
+def restore_and_save_temporaries(kernel):
+    """
+    Add code that loads / spills the temporaries in the kernel which are
+    live across sub-kernel calls.
+    """
     # Compute live temporaries.
     live_in, live_out = compute_live_temporaries(kernel, kernel.schedule)
 
@@ -252,14 +283,27 @@ def save_and_restore_temporaries(kernel):
     class PromotedTemporary(Record):
         """
         .. attribute:: name
+
+            The name of the new temporary.
+
         .. attribute:: orig_temporary
+
+            The original temporary variable object.
+
         .. attribute:: hw_inames
+
+            The common list of hw axes that define the original object.
+
         .. attribute:: shape_prefix
+
+            A list of expressions, to be added in front of the shape
+            of the promoted temporary value
         """
 
         def as_variable(self):
-            from loopy.kernel.data import TemporaryVariable
             temporary = self.orig_temporary
+            from loopy.kernel.data import TemporaryVariable
+            # XXX: This needs to be marked as global.
             return TemporaryVariable(
                 name=self.name,
                 dtype=temporary.dtype,
@@ -269,7 +313,6 @@ def save_and_restore_temporaries(kernel):
         def new_shape(self):
             return self.shape_prefix + self.orig_temporary.shape
 
-
     for temporary in inter_kernel_temporaries:
         from loopy.kernel.data import LocalIndexTag, temp_var_scope
 
@@ -302,7 +345,6 @@ def save_and_restore_temporaries(kernel):
                     static_max_of_pw_aff(
                         kernel.get_iname_bounds(iname).size, False)))
 
-        from loopy.kernel.data import TemporaryVariable
         backing_temporary = PromotedTemporary(
             name=name_gen(temporary.name),
             orig_temporary=temporary,
@@ -351,27 +393,23 @@ def save_and_restore_temporaries(kernel):
                         kernel, get_use_set(insn)))
             idx += 1
 
-        items_to_spill = subkernel_defs & live_out[idx]
-        # Need to load items_to_spill, to avoid overwriting entries that the
+        tvals_to_spill = subkernel_defs & live_out[idx]
+        # Need to load tvals_to_spill, to avoid overwriting entries that the
         # code doesn't touch when doing the spill.
-        items_to_load = (subkernel_uses | items_to_spill) & live_in[start_idx]
+        tvals_to_load = (subkernel_uses | tvals_to_spill) & live_in[start_idx]
 
         # Add arguments.
         new_schedule.append(
             sched_item.copy(extra_args=sorted(
-                set(new_temporaries[item].name
-                    for item in items_to_spill | items_to_load))))
-
-        from loopy.kernel.tools import DomainChanger
-        dchg = DomainChanger(kernel, frozenset(sched_item.extra_inames))
-        domain = dchg.get_original_domain()
+                set(new_temporaries[tval].name
+                    for tval in tvals_to_spill | tvals_to_load))))
 
         import islpy as isl
 
         # {{{ Add all the loads and spills.
 
-        def augment_domain(item, domain, mode_str):
-            temporary = new_temporaries[item]
+        def augment_domain(tval, domain, mode_str):
+            temporary = new_temporaries[tval]
             orig_size = domain.dim(isl.dim_type.set)
             dims_to_insert = len(temporary.orig_temporary.shape)
             # Add dimension-dependent inames.
@@ -390,10 +428,10 @@ def save_and_restore_temporaries(kernel):
                 dim_inames.append(new_iname)
                 # Add size information.
                 aff = isl.affs_from_space(domain.space)
-                domain &= aff[0].lt_set(aff[iname])
+                domain &= aff[0].le_set(aff[iname])
                 size = temporary.orig_temporary.shape[t_idx]
                 from loopy.symbolic import aff_from_expr
-                domain &= aff[iname].lt_set(aff_from_expr(domain.space, size))
+                domain &= aff[iname].le_set(aff_from_expr(domain.space, size))
 
             hw_inames = []
 
@@ -427,33 +465,55 @@ def save_and_restore_temporaries(kernel):
                     tuple(map(Variable, subscript)))
 
         from loopy.kernel.data import Assignment
-        for item in items_to_load:
-            domain, hw_inames, dim_inames = augment_domain(item, domain, "load")
+        # After loading local temporaries, we need to insert a barrier.
+        needs_local_barrier = False
+
+        from loopy.kernel.tools import DomainChanger
+        for tval in tvals_to_load:
+            tval_hw_inames = new_temporaries[tval].hw_inames
+            dchg = DomainChanger(kernel,
+                frozenset(sched_item.extra_inames + tval_hw_inames))
+            domain = dchg.domain
+
+            domain, hw_inames, dim_inames = augment_domain(tval, domain, "load")
+            kernel = dchg.get_kernel_with(domain)
 
             # Add a load instruction.
-            insn_id = name_gen("{name}.load".format(name=item))
+            insn_id = name_gen("{name}.load".format(name=tval))
 
             new_insn = Assignment(
                 subscript_or_var(
-                    item, dim_inames),
+                    tval, dim_inames),
                 subscript_or_var(
-                    new_temporaries[item].name, hw_inames + dim_inames),
+                    new_temporaries[tval].name, hw_inames + dim_inames),
                 id=insn_id)
 
             new_instructions.append(new_insn)
             subkernel_prolog.append(RunInstruction(insn_id=insn_id))
+            if new_temporaries[tval].orig_temporary.is_local:
+                needs_local_barrier = True
 
-        for item in items_to_spill:
-            domain, hw_inames, dim_inames = augment_domain(item, domain, "spill")
+        if needs_local_barrier:
+            from loopy.schedule import Barrier
+            subkernel_prolog.append(Barrier(kind="local"))
 
-            # Add a load instruction.
-            insn_id = name_gen("{name}.spill".format(name=item))
+        for tval in tvals_to_spill:
+            tval_hw_inames = new_temporaries[tval].hw_inames
+            dchg = DomainChanger(kernel,
+                frozenset(sched_item.extra_inames + tval_hw_inames))
+            domain = dchg.domain
+
+            domain, hw_inames, dim_inames = augment_domain(tval, domain, "spill")
+            kernel = dchg.get_kernel_with(domain)
+
+            # Add a spill instruction.
+            insn_id = name_gen("{name}.spill".format(name=tval))
 
             new_insn = Assignment(
                 subscript_or_var(
-                    new_temporaries[item].name, hw_inames + dim_inames),
+                    new_temporaries[tval].name, hw_inames + dim_inames),
                 subscript_or_var(
-                    item, dim_inames),
+                    tval, dim_inames),
                 id=insn_id)
 
             new_instructions.append(new_insn)
@@ -461,10 +521,6 @@ def save_and_restore_temporaries(kernel):
 
         # }}}
 
-        # DomainChanger returns a new kernel object, so we need to replace the
-        # kernel here.
-        kernel = dchg.get_kernel_with(domain)
-
         new_schedule.extend(
             subkernel_prolog +
             subkernel_schedule +
@@ -491,7 +547,7 @@ def save_and_restore_temporaries(kernel):
     return kernel
 
 
-def map_schedule_onto_host_or_device(kernel):
+def map_schedule_onto_host_or_device_impl(kernel):
     from loopy.schedule import (
         RunInstruction, EnterLoop, LeaveLoop, Barrier,
         CallKernel, ReturnFromKernel)
@@ -608,34 +664,4 @@ def map_schedule_onto_host_or_device(kernel):
     return new_kernel
 
 
-def get_hw_inames(kernel, insn):
-    from loopy.kernel.data import HardwareParallelTag
-    return set(iname for iname in kernel.insn_inames(insn)
-        if isinstance(kernel.iname_to_tag.get(iname), HardwareParallelTag))
-
-
-def analyze_temporaries(kernel):
-    # {{{ Analyze uses of temporaries by hardware loops
-
-    def_lists, use_lists = get_def_and_use_lists_for_all_temporaries(kernel)
-
-    for temporary in sorted(def_lists):
-        def_list = def_lists[temporary]
-
-        # Ensure that no use of the temporary is at a loop nesting level
-        # that is "more general" than the definition.
-        for use in use_lists[temporary]:
-            if not hw_inames <= get_hw_inames(insn):
-                raise ValueError(
-                    "Temporary variable `{temporary}` gets used in a more "
-                    "general hardware parallel loop than it is defined. "
-                    "(used by instruction id `{id}`, inames: {use_inames}) "
-                    "(defined in inames: {def_inames}).".format(
-                        temporary=temporary,
-                        id=use.id,
-                        use_inames=", ".join(sorted(get_hw_inames(insn))),
-                        def_inames=", ".join(sorted(hw_inames))))
-
-    # }}}
-
 # }}}
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index e06696614..c8f3b6b9e 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -317,6 +317,9 @@ class OpenCLTarget(CTarget):
 
         self.atomics_flavor = atomics_flavor
 
+    def split_kernel_at_global_barriers(self):
+        return True  # queried by map_schedule_onto_host_or_device()
+
     def get_device_ast_builder(self):
         return OpenCLCASTBuilder(self)
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 588cc9a20..fb19343c1 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2639,6 +2639,47 @@ def test_kernel_splitting_with_loop(ctx_factory):
     #lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
 
 
+def test_kernel_splitting_with_loop_and_temporaries(ctx_factory):
+    #ctx = ctx_factory()  # only needed by the disabled auto_test below
+
+    knl = lp.make_kernel(
+            "{ [i,k]: 0<=i<n and 0<=k<3 }",
+            """
+            <> t_extra_dim[i,0,i] = i
+            <> t_private = a[k,i+1]
+            <> t_local[k,i] = a[k,i+1]
+            c[k,i] = a[k,i+1] + t_extra_dim[i,0,i]
+            out[k,i] = c[k,i] + t_private + t_local[k,i]
+            """)
+
+    knl = lp.add_and_infer_dtypes(knl,
+            {"a": np.float32, "c": np.float32, "out": np.float32, "n": np.int32})
+
+    ref_knl = knl  # only referenced by the disabled auto_test below
+
+    knl = lp.split_iname(knl, "i", 128, outer_tag="g.0", inner_tag="l.0")
+
+    # preprocess, then pick one schedule
+    from loopy.preprocess import preprocess_kernel
+    knl = preprocess_kernel(knl)
+
+    from loopy.schedule import get_one_scheduled_kernel
+    knl = get_one_scheduled_kernel(knl)
+
+    # print the scheduled kernel for inspection
+    print(knl)
+
+    cgr = lp.generate_code_v2(knl)
+
+    assert len(cgr.device_programs) == 3
+
+    print(cgr.device_code())
+    print(cgr.host_code())
+
+    # TODO: Doesn't yet work--not passing k, temporaries
+    #lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=5))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab