From d123027baca1d798c947b84aee08eb5fc78e39c6 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 23 Jul 2011 15:28:56 -0500
Subject: [PATCH] ILP generating, still subtly broken. (ILP 8 bug)

---
 examples/cmdl          |   1 +
 examples/matrix-ops.py |  22 ++-
 loopy/__init__.py      | 377 ++++++++++++++++++++++++++++-------------
 3 files changed, 273 insertions(+), 127 deletions(-)
 create mode 100644 examples/cmdl

diff --git a/examples/cmdl b/examples/cmdl
new file mode 100644
index 000000000..2ea6c30c5
--- /dev/null
+++ b/examples/cmdl
@@ -0,0 +1 @@
+EDITOR=vim PYOPENCL_CTX='1:' COMPUTE_PROFILE=0 python matrix-ops.py 'image_matrix_mul_ilp()'
diff --git a/examples/matrix-ops.py b/examples/matrix-ops.py
index 592cfa1aa..e76acc2e8 100644
--- a/examples/matrix-ops.py
+++ b/examples/matrix-ops.py
@@ -168,7 +168,6 @@ def image_matrix_mul(ctx_factory=cl.create_some_context):
             force_rebuild=True)
 
 
-
 def dg_matrix_mul(ctx_factory=cl.create_some_context):
     dtype = np.float32
     ctx = ctx_factory()
@@ -180,7 +179,7 @@ def dg_matrix_mul(ctx_factory=cl.create_some_context):
     Np_padded = 96
     K = 20000
     dim = 3
-    num_flds = 6
+    num_flds = 2
 
     from pymbolic import var
     fld = var("fld")
@@ -209,8 +208,11 @@ def dg_matrix_mul(ctx_factory=cl.create_some_context):
                 ],
             name="dg_matmul")
 
+    ilp = 4
     knl = lp.split_dimension(knl, "i", 30, 32, outer_tag="g.0", inner_tag="l.0")
-    knl = lp.split_dimension(knl, "k", 16, outer_tag="g.1", inner_tag="l.1")
+    knl = lp.split_dimension(knl, "k", 16*ilp, outer_tag="g.1")
+    knl = lp.split_dimension(knl, "k_inner", 16, outer_tag="ilp", inner_tag="l.1")
+
     assert Np % 2 == 0
     #knl = lp.split_dimension(knl, "j", Np//2)
     #knl = lp.split_dimension(knl, "k", 32)
@@ -218,7 +220,8 @@ def dg_matrix_mul(ctx_factory=cl.create_some_context):
     #for mn in matrix_names:
         #knl = lp.add_prefetch(knl, mn, ["j", "i_inner"])
     for ifld in range(num_flds):
-        knl = lp.add_prefetch(knl, 'fld%d' % ifld, ["k_inner", "j"])
+        knl = lp.add_prefetch(knl, 'fld%d' % ifld,
+                ["k_inner_outer", "k_inner_inner", "j"])
     assert knl.get_invalid_reason() is None
 
     kernel_gen = list(lp.insert_register_prefetches(knl)
@@ -254,7 +257,8 @@ def dg_matrix_mul(ctx_factory=cl.create_some_context):
 
     lp.drive_timing_run(kernel_gen, queue, launcher, num_flds*dim*2*(Np**2)*K,
             options=FAST_OPTIONS + ["-cl-nv-verbose"],
-            force_rebuild=True, edit=True
+            force_rebuild=True, #, edit=True
+            print_code=False
             )
 
 
@@ -286,10 +290,10 @@ def image_matrix_mul_ilp(ctx_factory=cl.create_some_context):
                 ],
             name="matmul")
 
-    ilp = 4
+    ilp = 8
     knl = lp.split_dimension(knl, "i", 16, outer_tag="g.0", inner_tag="l.1")
     knl = lp.split_dimension(knl, "j", ilp*16, outer_tag="g.1")
-    knl = lp.split_dimension(knl, "j_inner", ilp, outer_tag="unr", inner_tag="l.0")
+    knl = lp.split_dimension(knl, "j_inner", 16, outer_tag="ilp", inner_tag="l.0")
     knl = lp.split_dimension(knl, "k", 32)
     # conflict-free
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
@@ -316,8 +320,8 @@ def image_matrix_mul_ilp(ctx_factory=cl.create_some_context):
         return evt
 
     lp.drive_timing_run(kernel_gen, queue, launcher, 2*n**3,
-            options=FAST_OPTIONS + ["-cl-nv-verbose"],
-            force_rebuild=True)
+            options=FAST_OPTIONS,# + ["-cl-nv-verbose"],
+            force_rebuild=True, edit_code=False)
 
 
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index aadb8c8b3..570bdd089 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -22,12 +22,23 @@ register_mpz_with_pymbolic()
 
 
 
-# TODO: Fetch dim merging
+# TODO: Reuse of previously split dimensions for prefetch
+#   (Or general merging)
+# TODO: ILP Unroll
+#  ILP dep chain:
+#     Prefetch needs value for k_outer
+#     k_outer is outermost reduction loop
+#     ILP must be outside of reduction loops
+#     Therefore, there are prefetches inside ILPs
+# TODO: Debug 1 ILP
+
 # TODO: Try, fix reg. prefetch (DG example) / CSEs
 # TODO: Custom reductions per red. axis
 # TODO: Functions
 # TODO: Common subexpressions
 # TODO: Parse ops from string
+# FIXME: support non-reductive dimensions
+# FIXME: write names should be assigned during scheduling
 
 # TODO: Condition hoisting
 # TODO: Don't emit spurious barriers (no for scheduled before)
@@ -39,10 +50,8 @@ register_mpz_with_pymbolic()
 # TODO:   - Tricky: Convolution, FD
 # TODO: Try, fix indirect addressing
 # TODO: User controllable switch for slab opt
-# TODO: User control over schedule
 # TODO: Separate all-bulk from non-bulk kernels.
 
-# TODO: ILP Unroll
 
 # TODO: implement efficient div_ceil?
 # TODO: why are corner cases inefficient?
@@ -67,7 +76,10 @@ class IndexTag(Record):
 
 
 
-class AxisParallelTag(IndexTag):
+class ParallelTag(IndexTag):
+    pass
+
+class AxisParallelTag(ParallelTag):
     __slots__ = ["axis", "forced_length"]
 
     def __init__(self, axis, forced_length=None):
@@ -93,9 +105,9 @@ class TAG_GROUP_IDX(AxisParallelTag):
 class TAG_WORK_ITEM_IDX(AxisParallelTag):
     print_name = "WORK_ITEM_IDX"
 
-class TAG_ILP_UNROLL(IndexTag):
+class TAG_ILP(ParallelTag):
     def __repr__(self):
-        return "TAG_ILP_UNROLL"
+        return "TAG_ILP"
 
 class BaseUnrollTag(IndexTag):
     pass
@@ -123,7 +135,7 @@ def parse_tag(tag):
     elif tag == "unri":
         return TAG_UNROLL_INCR()
     elif tag == "ilp":
-        return TAG_ILP_UNROLL
+        return TAG_ILP()
     elif tag.startswith("g."):
         return TAG_GROUP_IDX(int(tag[2:]))
     elif tag.startswith("l."):
@@ -942,7 +954,7 @@ class RegisterPrefetch(Record):
 
 # }}}
 
-def generate_loop_schedules(kernel):
+def generate_loop_schedules(kernel, hints=[]):
     prev_schedule = kernel.schedule
     if prev_schedule is None:
         prev_schedule = [
@@ -957,9 +969,10 @@ def generate_loop_schedules(kernel):
 
     # have a schedulable prefetch? load, schedule it
     had_usable_prefetch = False
-    scheduled_work_item_inames = set(
+    locally_parallel_inames = set(
             iname for iname in scheduled_inames
-            if isinstance(kernel.iname_to_tag.get(iname), TAG_WORK_ITEM_IDX))
+            if isinstance(kernel.iname_to_tag.get(iname), 
+                (TAG_ILP, TAG_WORK_ITEM_IDX)))
 
     for pf in kernel.prefetch.itervalues():
         # already scheduled? never mind then.
@@ -973,7 +986,7 @@ def generate_loop_schedules(kernel):
         # a prefetch variable already scheduled, but not borrowable?
         # (only work item index variables are borrowable)
 
-        if set(pf.inames) & (scheduled_inames - scheduled_work_item_inames):
+        if set(pf.inames) & (scheduled_inames - locally_parallel_inames):
             # dead end: we won't be able to schedule this prefetch
             # in this branch. at least one of its loop dimensions
             # was already scheduled, and that dimension is not
@@ -988,21 +1001,33 @@ def generate_loop_schedules(kernel):
             yield knl
 
     if had_usable_prefetch:
+        # because we've already recursed
         return
 
     # Build set of potentially schedulable variables
     # Don't re-schedule already scheduled variables
     schedulable = kernel.all_inames() - scheduled_inames
 
+    # Schedule in the following order:
+    # - serial output inames
+    # - remaining parallel output inames (i.e. ILP)
+    # - output write
+    # - reduction
     # Don't schedule reduction variables until all output
     # variables are taken care of. Once they are, schedule
     # output writing.
-    serial_output_inames = set(oin for oin in kernel.output_inames()
-            if kernel.iname_to_tag.get(oin) is None)
+    parallel_output_inames = set(oin for oin in kernel.output_inames()
+            if isinstance(kernel.iname_to_tag.get(oin), ParallelTag))
 
-    if not serial_output_inames <= scheduled_inames:
-        schedulable -= kernel.reduction_inames()
-    else:
+    serial_output_inames = kernel.output_inames() - parallel_output_inames
+
+    if schedulable & serial_output_inames:
+        schedulable = schedulable & serial_output_inames
+
+    if schedulable & parallel_output_inames:
+        schedulable  = schedulable & parallel_output_inames
+
+    if kernel.output_inames() <= scheduled_inames:
         if not any(isinstance(sch_item, WriteOutput)
                 for sch_item in prev_schedule):
             kernel = kernel.copy(
@@ -1014,19 +1039,31 @@ def generate_loop_schedules(kernel):
     unsched_prefetch_axes = set(iname
             for pf in kernel.prefetch.itervalues()
             if pf not in prev_schedule
-            for iname in pf.inames)
+            for iname in pf.inames
+            if not isinstance(kernel.iname_to_tag.get(iname), ParallelTag))
     schedulable -= unsched_prefetch_axes
 
+    while hints and hints[0] in scheduled_inames:
+        hints = hints[1:]
+
+    if hints and hints[0] in schedulable:
+        schedulable = set(hints[0])
+
     if schedulable:
         # have a schedulable variable? schedule a loop for it, recurse
         for iname in schedulable:
             new_kernel = kernel.copy(schedule=prev_schedule+[ScheduledLoop(iname=iname)])
-            for knl in generate_loop_schedules(new_kernel):
+            for knl in generate_loop_schedules(new_kernel, hints):
                 yield knl
     else:
         # all loop dimensions and prefetches scheduled?
         # great! yield the finished product if it is complete
 
+        if hints:
+            from warnings import warn
+            warn("leftover schedule hints: "+ (", ".join(hints)),
+                    LoopyAdvisory)
+
         all_inames_scheduled = len(scheduled_inames) == len(kernel.all_inames())
         all_pf_scheduled =  len(set(sch_item for sch_item in prev_schedule
             if isinstance(sch_item, PrefetchDescriptor))) == len(kernel.prefetch)
@@ -1135,7 +1172,7 @@ def gen_code_block(elements):
         ast = Block(block_els)
     return GeneratedCode(ast=ast, num_conditionals=num_conditionals)
 
-def wrap_with(cls, *args):
+def wrap_in(cls, *args):
     inner = args[-1]
     args = args[:-1]
 
@@ -1158,6 +1195,16 @@ def wrap_with(cls, *args):
 
     return GeneratedCode(ast=ast, num_conditionals=num_conditionals)
 
+def wrap_in_if(condition_codelets, inner):
+    from cgen import If
+
+    if condition_codelets:
+        return wrap_in(If,
+                "\n&& ".join(condition_codelets),
+                inner)
+
+    return inner
+
 # }}}
 
 # {{{ C code mapper
@@ -1328,7 +1375,7 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
                     new_impl_domain)
 
             if cur_index+total_realiz_size > dim_length:
-                inner = wrap_with(If,
+                inner = wrap_in(If,
                         "%s < %s" % (ccm(pf_dim_expr), stop_index),
                         inner)
 
@@ -1357,7 +1404,7 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
                 pf_dim_exprs+[pf_dim_expr], pf_idx_subst_map,
                 new_impl_domain)
 
-        return wrap_with(For,
+        return wrap_in(For,
                 "int %s = 0" % pf_dim_var,
                 "%s < %s" % (pf_dim_var, ccm(dim_length)),
                 "++%s" % pf_dim_var,
@@ -1366,7 +1413,9 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
         # }}}
 
 
-def generate_prefetch_code(cgs, kernel, sched_index, implemented_domain):
+def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
+    implemented_domain = exec_domain.implemented_domain
+
     from cgen import Statement as S, Line, Comment
 
     ccm = cgs.c_code_mapper
@@ -1519,7 +1568,7 @@ def generate_prefetch_code(cgs, kernel, sched_index, implemented_domain):
             "no sync needed"))
 
     new_block.extend([Line(),
-        build_loop_nest(cgs, kernel, sched_index+1, implemented_domain)])
+        build_loop_nest(cgs, kernel, sched_index+1, exec_domain)])
 
     return gen_code_block(new_block)
 
@@ -1527,31 +1576,62 @@ def generate_prefetch_code(cgs, kernel, sched_index, implemented_domain):
 
 # {{{ per-axis loop nest code generation
 
+class ExecutionDomain(object):
+    def __init__(self, implemented_domain, assignments_and_impl_domains=None):
+        """
+        :param implemented_domain: The entire implemented domain,
+            i.e. all constraints that have been enforced so far.
+        :param assignments_and_impl_domains: a list of tuples 
+            (assignments, implemented_domain), where *assignments*
+            is a list of :class:`cgen.Assignment` instances
+            and *implemented_domain* is the implemented domain to which
+            the situation produced by the assignments corresponds.
+
+            The point of this being a list is the implementation of
+            ILP, and each entry represents a 'fake-parallel' trip through the 
+            ILP'd loop.
+        """
+        if assignments_and_impl_domains is None:
+            assignments_and_impl_domains = [([], implemented_domain)]
+        self.implemented_domain = implemented_domain
+        self.assignments_and_impl_domains = assignments_and_impl_domains
+
+    def __len__(self):
+        return len(self.assignments_and_impl_domains)
+
+    def __iter__(self):
+        return iter(self.assignments_and_impl_domains)
+
+    def intersect(self, set):
+        return ExecutionDomain(
+                self.implemented_domain.intersect(set),
+                [(assignments, implemented_domain.intersect(set))
+                for assignments, implemented_domain
+                in self.assignments_and_impl_domains])
+
+    def get_the_one_domain(self):
+        assert len(self.assignments_and_impl_domains) == 1
+        return self.implemented_domain
+
+
+
+
 def generate_loop_dim_code(cgs, kernel, sched_index,
-        implemented_domain):
+        exec_domain):
     from cgen import (Comment, add_comment, make_multiple_ifs,
-            POD, Initializer, Assign, Line, Const)
+            POD, Assign, Line, Statement as S)
 
     ccm = cgs.c_code_mapper
 
-    space = implemented_domain.get_dim()
+    space = kernel.space
 
     iname = kernel.schedule[sched_index].iname
     tag = kernel.iname_to_tag.get(iname)
 
-    if isinstance(tag, BaseUnrollTag):
-        if 0:
-            # FIXME reactivate?
-            lower, upper, equality = get_bounds_constraints(kernel.domain, iname,
-                    admissible_vars=get_valid_index_vars(kernel, sched_index+1))
-
-            print lower
-            lower_cns, = filter_necessary_constraints(implemented_domain, lower)
-            upper_cns, = filter_necessary_constraints(implemented_domain, upper)
-        else:
-            lower_cns, upper_cns = kernel.get_projected_bounds_constraints(iname)
-            lower_cns = cast_constraint_to_space(lower_cns, space)
-            upper_cns = cast_constraint_to_space(upper_cns, space)
+    if isinstance(tag, (BaseUnrollTag, TAG_ILP)):
+        lower_cns, upper_cns = kernel.get_projected_bounds_constraints(iname)
+        lower_cns = cast_constraint_to_space(lower_cns, space)
+        upper_cns = cast_constraint_to_space(upper_cns, space)
 
         lower_kind, lower_bound = solve_constraint_for_bound(lower_cns, iname)
         upper_kind, upper_bound = solve_constraint_for_bound(upper_cns, iname)
@@ -1564,26 +1644,56 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
         cfm = CommutativeConstantFoldingMapper()
         length = int(cfm(flatten(upper_bound-lower_bound)))
 
-        result = [POD(np.int32, iname), Line()]
-
-        for i in xrange(length):
-            slab = (isl.Set.universe(kernel.space)
-                    .add_constraint(
-                            block_shift_constraint(
-                                lower_cns, iname, -i, as_equality=True)))
-
-            new_impl_domain = implemented_domain.intersect(slab)
-            inner = build_loop_nest(cgs, kernel, sched_index+1,
-                    new_impl_domain)
-
-            if isinstance(tag, TAG_UNROLL_STATIC):
-                result.extend([
-                    Assign(iname, ccm(lower_bound+i)),
-                    Line(), inner])
-            elif isinstance(tag, TAG_UNROLL_INCR):
-                result.append(S("++%s" % iname))
-
-        return gen_code_block(result)
+        def generate_idx_eq_slabs():
+            for i in xrange(length):
+                yield (i, isl.Set.universe(kernel.space)
+                        .add_constraint(
+                                block_shift_constraint(
+                                    lower_cns, iname, -i, as_equality=True)))
+
+        if isinstance(tag, BaseUnrollTag):
+            result = [POD(np.int32, iname), Line()]
+
+            for i, slab in generate_idx_eq_slabs():
+                new_exec_domain = exec_domain.intersect(slab)
+                inner = build_loop_nest(cgs, kernel, sched_index+1,
+                        new_exec_domain)
+
+                if isinstance(tag, TAG_UNROLL_STATIC):
+                    result.extend([
+                        Assign(iname, ccm(lower_bound+i)),
+                        Line(), inner])
+                elif isinstance(tag, TAG_UNROLL_INCR):
+                    result.append(S("++%s" % iname))
+
+            return gen_code_block(result)
+
+        elif isinstance(tag, TAG_ILP):
+            new_aaid = []
+            for assignments, implemented_domain in exec_domain:
+                for i, single_slab in generate_idx_eq_slabs():
+                    assignments = assignments + [
+                            Assign(iname, ccm(lower_bound+i))]
+                    new_aaid.append((assignments, 
+                        implemented_domain.intersect(single_slab)))
+
+                    assignments = []
+
+            overall_slab = (isl.Set.universe(kernel.space)
+                    .add_constraint(lower_cns)
+                    .add_constraint(upper_cns))
+
+            return gen_code_block([
+                Comment("declare ILP'd variable"),
+                POD(np.int32, iname),
+                Line(),
+                build_loop_nest(cgs, kernel, sched_index+1,
+                    ExecutionDomain(
+                        exec_domain.implemented_domain.intersect(overall_slab),
+                        new_aaid))
+                ])
+        else:
+            assert False, "not supposed to get here"
 
     lb_cns_orig, ub_cns_orig = kernel.get_projected_bounds_constraints(iname)
     lb_cns_orig = cast_constraint_to_space(lb_cns_orig, space)
@@ -1608,19 +1718,18 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
             bulk_slab = (isl.Set.universe(kernel.space)
                     .add_constraint(lb_cns)
                     .add_constraint(ub_cns))
-            bulk_impl_domain = implemented_domain.intersect(bulk_slab)
-            if not bulk_impl_domain.is_empty():
-                inner = build_loop_nest(trial_cgs, kernel, sched_index+1,
-                        bulk_impl_domain)
-
-                trials.append((TrialRecord(
-                    lower_incr=lower_incr,
-                    upper_incr=upper_incr,
-                    bulk_slab=bulk_slab),
-                    (inner.num_conditionals,
-                        # when all num_conditionals are equal, choose the
-                        # one with the smallest bounds changes
-                        abs(upper_incr)+abs(lower_incr))))
+            bulk_exec_domain = exec_domain.intersect(bulk_slab)
+            inner = build_loop_nest(trial_cgs, kernel, sched_index+1,
+                    bulk_exec_domain)
+
+            trials.append((TrialRecord(
+                lower_incr=lower_incr,
+                upper_incr=upper_incr,
+                bulk_slab=bulk_slab),
+                (inner.num_conditionals,
+                    # when all num_conditionals are equal, choose the
+                    # one with the smallest bounds changes
+                    abs(upper_incr)+abs(lower_incr))))
 
         from pytools import argmin2
         chosen = argmin2(trials)
@@ -1663,14 +1772,14 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
         # by an if below.
 
         if tag.forced_length is None:
-            implemented_domain = implemented_domain.intersect(
+            exec_domain = exec_domain.intersect(
                     isl.Set.universe(kernel.space)
                     .add_constraint(lb_cns_orig)
                     .add_constraint(ub_cns_orig))
         else:
             impl_len = tag.forced_length
             start, _ = kernel.get_projected_bounds(iname)
-            implemented_domain = implemented_domain.intersect(
+            exec_domain = exec_domain.intersect(
                     make_slab(kernel.space, iname, start, start+impl_len))
 
     result = []
@@ -1681,9 +1790,9 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
         if len(slabs) == 1:
             cmt = None
 
-        new_impl_domain = implemented_domain.intersect(slab)
+        new_exec_domain = exec_domain.intersect(slab)
         inner = build_loop_nest(cgs, kernel, sched_index+1,
-                new_impl_domain)
+                new_exec_domain)
 
         if tag is None:
             # regular loop
@@ -1693,10 +1802,12 @@ def generate_loop_dim_code(cgs, kernel, sched_index,
                     wrap_in_for_from_constraints(ccm, iname, slab, inner))
         else:
             # parallel loop
+            par_impl_domain = exec_domain.get_the_one_domain()
+
             nums_of_conditionals.append(inner.num_conditionals)
             constraint_codelets = generate_bounds_checks(ccm,
                     slab, get_valid_index_vars(kernel, sched_index+1),
-                    implemented_domain)
+                    par_impl_domain)
             result.append(
                     ("\n&& ".join(constraint_codelets),
                         add_comment(cmt, inner.ast)))
@@ -1763,17 +1874,10 @@ def generate_bounds_checks(ccm, domain, valid_index_vars, implemented_domain):
     return [constraint_to_code(ccm, cns) for cns in necessary_constraints]
 
 def wrap_in_bounds_checks(ccm, domain, valid_index_vars, implemented_domain, stmt):
-    from cgen import If
-
-    constraint_codelets = generate_bounds_checks(ccm, domain, valid_index_vars,
-            implemented_domain)
-
-    if constraint_codelets:
-        stmt = wrap_with(If,
-                "\n&& ".join(constraint_codelets),
-                stmt)
-
-    return stmt
+    return wrap_in_if(
+            generate_bounds_checks(ccm, domain, valid_index_vars,
+                implemented_domain),
+            stmt)
 
 def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
     # FIXME add admissible vars
@@ -1817,7 +1921,7 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
     start_expr, = start_exprs # there has to be at least one
 
     from cgen import For
-    return wrap_with(For,
+    return wrap_in(For,
             "int %s = %s" % (iname, start_expr),
             " && ".join(end_conds),
             "++%s" % iname,
@@ -1827,26 +1931,44 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
 
 # {{{ loop nest build top-level dispatch
 
-def build_loop_nest(cgs, kernel, sched_index, implemented_domain):
+def build_loop_nest(cgs, kernel, sched_index, exec_domain):
+    assert isinstance(exec_domain, ExecutionDomain)
+
     ccm = cgs.c_code_mapper
 
-    from cgen import (POD, Initializer, Assign, Statement as S)
+    from cgen import (POD, Initializer, Assign, Statement as S,
+            Line, Block)
 
     if sched_index >= len(kernel.schedule):
         # {{{ write innermost loop body
 
         from pymbolic.primitives import Subscript
 
-        insns = []
+        # FIXME revert to unroll if actual bounds checks are needed?
+
+        valid_index_vars = get_valid_index_vars(kernel, sched_index)
+        bounds_check_lists = [
+                generate_bounds_checks(ccm, kernel.domain,
+                    valid_index_vars, impl_domain)
+                for assignments, impl_domain in
+                    exec_domain]
+
+        result = []
         for lvalue, expr in kernel.instructions:
-            assert isinstance(lvalue, Subscript)
-            name = lvalue.aggregate.name
-            insns.append(S("tmp_%s += %s"
-                % (name, ccm(expr))))
+            for i, (assignments, impl_domain) in \
+                    enumerate(exec_domain):
+
+                result.extend(assignments+[Line()])
 
-        return wrap_in_bounds_checks(ccm, kernel.domain,
-                get_valid_index_vars(kernel, sched_index),
-                implemented_domain, gen_code_block(insns))
+                assert isinstance(lvalue, Subscript)
+                name = lvalue.aggregate.name
+                result.append(
+                        wrap_in_if(
+                            bounds_check_lists[i],
+                            S("tmp_%s_%d += %s"
+                                % (name, i, ccm(expr)))))
+
+        return gen_code_block(result)
 
         # }}}
 
@@ -1854,27 +1976,45 @@ def build_loop_nest(cgs, kernel, sched_index, implemented_domain):
 
     if isinstance(sched_item, ScheduledLoop):
         return generate_loop_dim_code(cgs, kernel, sched_index,
-                implemented_domain)
+                exec_domain)
 
     elif isinstance(sched_item, WriteOutput):
-        return gen_code_block(
+        result = (
                 [Initializer(POD(kernel.arg_dict[lvalue.aggregate.name].dtype,
-                    "tmp_"+lvalue.aggregate.name), 0)
+                    "tmp_%s_%d" % (lvalue.aggregate.name, i)), 0)
+                    for i in range(len(exec_domain))
                     for lvalue, expr in kernel.instructions]
-                +[build_loop_nest(cgs, kernel, sched_index+1, implemented_domain)]+
-                [wrap_in_bounds_checks(ccm, kernel.domain,
-                    get_valid_index_vars(kernel, sched_index),
-                    implemented_domain,
-                    gen_code_block([
-                        Assign(
-                            ccm(lvalue),
-                            "tmp_"+lvalue.aggregate.name)
-                        for lvalue, expr in kernel.instructions]))])
+                +[build_loop_nest(cgs, kernel, sched_index+1, 
+                    exec_domain)])
+
+
+        for i, (idx_assignments, impl_domain) in \
+                enumerate(exec_domain):
+            for lvalue, expr in kernel.instructions:
+                assignment = Assign(ccm(lvalue), "tmp_%s_%d" % (
+                    lvalue.aggregate.name, i))
+
+                wrapped_assign = wrap_in_bounds_checks(
+                        ccm, kernel.domain,
+                        get_valid_index_vars(kernel, sched_index),
+                        impl_domain, assignment)
+
+                result.extend(idx_assignments)
+                result.extend([
+                    Line(),
+                    wrapped_assign,
+                    Line(),
+                    ])
+
+        return gen_code_block(result)
 
     elif isinstance(sched_item, PrefetchDescriptor):
-        return generate_prefetch_code(cgs, kernel, sched_index, implemented_domain)
+        return generate_prefetch_code(cgs, kernel, sched_index, 
+                exec_domain)
 
     elif isinstance(sched_item, RegisterPrefetch):
+        raise NotImplementedError("reg prefetch") # FIXME
+
         agg_name = sched_item.subscript_expr.aggregate.name
         return gen_code_block([
             wrap_in_bounds_checks(ccm, kernel, sched_index, implemented_domain,
@@ -1884,7 +2024,7 @@ def build_loop_nest(cgs, kernel, sched_index, implemented_domain):
                     % (agg_name,
                         ccm(sched_item.subscript_expr.index)))),
 
-            build_loop_nest(cgs, kernel, sched_index+1, implemented_domain)
+            build_loop_nest(cgs, kernel, sched_index+1, exec_domain)
             ])
 
     else:
@@ -2078,7 +2218,8 @@ def generate_code(kernel):
     # }}}
 
     cgs = CodeGenerationState(c_code_mapper=ccm, try_slab_partition=True)
-    gen_code = build_loop_nest(cgs, kernel, 0, isl.Set.universe(kernel.space))
+    gen_code = build_loop_nest(cgs, kernel, 0, 
+            ExecutionDomain(isl.Set.universe(kernel.space)))
     body.extend([Line(), gen_code.ast])
     #print "# conditionals: %d" % gen_code.num_conditionals
 
@@ -2162,7 +2303,7 @@ def add_prefetch(kernel, input_access_descr, tags_or_inames, loc_fetch_axes={}):
 
 class CompiledKernel:
     def __init__(self, context, kernel, size_args=None, options=[],
-            force_rebuild=False, edit=False):
+            force_rebuild=False, edit_code=False):
         self.kernel = kernel
         self.code = generate_code(kernel)
 
@@ -2170,7 +2311,7 @@ class CompiledKernel:
             from time import time
             self.code = "/* %s */\n%s" % (time(), self.code)
 
-        if edit:
+        if edit_code:
             from pytools import invoke_editor
             self.code = invoke_editor(self.code)
 
@@ -2219,7 +2360,7 @@ class CompiledKernel:
 # {{{ timing driver
 def drive_timing_run(kernel_generator, queue, launch, flop_count=None,
         options=[], print_code=True, force_rebuild=False,
-        edit=False):
+        edit_code=False):
 
     def time_run(compiled_knl, warmup_rounds=2, timing_rounds=5):
         check = True
@@ -2244,7 +2385,7 @@ def drive_timing_run(kernel_generator, queue, launch, flop_count=None,
     for kernel in kernel_generator:
 
         compiled = CompiledKernel(queue.context, kernel, options=options,
-                force_rebuild=force_rebuild, edit=edit)
+                force_rebuild=force_rebuild, edit_code=edit_code)
 
         print "-----------------------------------------------"
         print "SOLUTION #%d" % soln_count
-- 
GitLab