From 39f63b41e986a7c2eed1d4600970a9c487a50ef2 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 9 Aug 2011 00:53:52 +0200
Subject: [PATCH] Ensure all work items hit all barriers. Add condition
 hoisting. Tests pass.

---
 loopy/__init__.py         |   1 -
 loopy/codegen/__init__.py |  26 +++-
 loopy/codegen/bounds.py   |  51 ++++++-
 loopy/codegen/dispatch.py |  54 ++++++--
 loopy/codegen/loop_dim.py | 276 ++++++++++++++++++++------------------
 loopy/codegen/prefetch.py |   7 +-
 loopy/kernel.py           |  10 +-
 loopy/schedule.py         |  17 ---
 test/test_matmul.py       |   9 +-
 9 files changed, 272 insertions(+), 179 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index e7ff8c425..f21301d30 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -23,7 +23,6 @@ register_mpz_with_pymbolic()
 # FIXME: support non-reductive dimensions
 # FIXME: write names should be assigned during scheduling
 
-# TODO: Condition hoisting
 # TODO: Don't emit spurious barriers (no for scheduled before)
 # TODO: Make code more readable
 
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index 67af2676e..e8415bdbb 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -15,24 +15,35 @@ class GeneratedCode(Record):
     """
     __slots__ = ["ast", "num_conditionals"]
 
-def gen_code_block(elements):
+def gen_code_block(elements, is_alternatives=False):
+    """
+    :param is_alternatives: a :class:`bool` indicating that
+        only one of the *elements* will effectively be executed.
+    """
+
     from cgen import Generable, Block
 
-    num_conditionals = 0
+    conditional_counts = []
     block_els = []
     for el in elements:
         if isinstance(el, GeneratedCode):
-            num_conditionals = num_conditionals + el.num_conditionals
+            conditional_counts.append(el.num_conditionals)
             block_els.append(el.ast)
         elif isinstance(el, Generable):
             block_els.append(el)
         else:
             raise ValueError("unidentifiable object in block")
 
+    if is_alternatives:
+        num_conditionals = min(conditional_counts)
+    else:
+        num_conditionals = sum(conditional_counts)
+
     if len(block_els) == 1:
         ast, = block_els
     else:
         ast = Block(block_els)
+
     return GeneratedCode(ast=ast, num_conditionals=num_conditionals)
 
 def wrap_in(cls, *args):
@@ -68,6 +79,15 @@ def wrap_in_if(condition_codelets, inner):
 
     return inner
 
+def add_comment(cmt, code):
+    if cmt is None:
+        return code
+
+    from cgen import add_comment, Block
+    block_with_comment = add_comment(cmt, code.ast)
+    assert isinstance(block_with_comment, Block)
+    return gen_code_block(block_with_comment.contents)
+
 # }}}
 
 # {{{ main code generation entrypoint
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index b17f035f3..cacea9f15 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -25,11 +25,12 @@ def filter_necessary_constraints(implemented_domain, constraints):
             isl.Set.universe(space)
             .add_constraint(cns))]
 
-def generate_bounds_checks(ccm, domain, check_vars, implemented_domain):
+def generate_bounds_checks(domain, check_vars, implemented_domain):
     domain_bset, = domain.get_basic_sets()
 
     projected_domain_bset = isl.project_out_except(
             domain_bset, check_vars, [dim_type.set])
+    projected_domain_bset = projected_domain_bset.remove_divs()
 
     space = domain.get_dim()
 
@@ -44,15 +45,17 @@ def generate_bounds_checks(ccm, domain, check_vars, implemented_domain):
 
     projected_domain_bset.foreach_constraint(examine_constraint)
 
-    necessary_constraints = filter_necessary_constraints(
+    return filter_necessary_constraints(
             implemented_domain, cast_constraints)
 
-    return [constraint_to_code(ccm, cns) for cns in necessary_constraints]
+def generate_bounds_checks_code(ccm, domain, check_vars, implemented_domain):
+    return [constraint_to_code(ccm, cns) for cns in 
+            generate_bounds_checks(domain, check_vars, implemented_domain)]
 
 def wrap_in_bounds_checks(ccm, domain, check_vars, implemented_domain, stmt):
     from loopy.codegen import wrap_in_if
     return wrap_in_if(
-            generate_bounds_checks(ccm, domain, check_vars,
+            generate_bounds_checks_code(ccm, domain, check_vars,
                 implemented_domain),
             stmt)
 
@@ -110,6 +113,46 @@ def wrap_in_for_from_constraints(ccm, iname, constraint_bset, stmt):
 
 # }}}
 
+def get_valid_check_vars(kernel, sched_index, allow_ilp, exclude_tag_classes=()):
+    """
+    :param exclude_tags: a tuple of tag classes to exclude
+    """
+
+    if not allow_ilp:
+        from loopy.kernel import TAG_ILP
+        exclude_tag_classes = exclude_tag_classes + (TAG_ILP,)
+
+    from loopy.schedule import ScheduledLoop
+    allowed_vars = set(
+            sched_item.iname
+            for sched_item in kernel.schedule[:sched_index]
+            if isinstance(sched_item, ScheduledLoop))
+
+    from pytools import any
+    from loopy.prefetch import LocalMemoryPrefetch
+    all_lmem_prefetches_scheduled = not any(
+            isinstance(sched_item, LocalMemoryPrefetch)
+            for sched_item in kernel.schedule[sched_index:])
+
+    if not all_lmem_prefetches_scheduled:
+        # Lmem prefetches use barriers. Barriers are only allowed if *all* work
+        # items in a work group hit them. Therefore, as long as not all lmem
+        # prefetches are scheduled, we may not check work item indices
+        # (and thereby conceivably mask out some work items).
+
+        from loopy.kernel import TAG_WORK_ITEM_IDX
+        allowed_vars -= set(kernel.inames_by_tag_type(TAG_WORK_ITEM_IDX))
+
+    allowed_vars = set(
+            iname
+            for iname in allowed_vars
+            if not isinstance(
+                kernel.iname_to_tag.get(iname),
+                exclude_tag_classes))
+
+    return allowed_vars
+
+
 
 
 
diff --git a/loopy/codegen/dispatch.py b/loopy/codegen/dispatch.py
index e7ed34c18..f13b501a5 100644
--- a/loopy/codegen/dispatch.py
+++ b/loopy/codegen/dispatch.py
@@ -6,7 +6,7 @@ from loopy.codegen import ExecutionDomain, gen_code_block
 
 
 
-def build_loop_nest(cgs, kernel, sched_index, exec_domain):
+def build_loop_nest(cgs, kernel, sched_index, exec_domain, no_conditional_check=False):
     assert isinstance(exec_domain, ExecutionDomain)
 
     ccm = cgs.c_code_mapper
@@ -14,8 +14,39 @@ def build_loop_nest(cgs, kernel, sched_index, exec_domain):
     from cgen import (POD, Initializer, Assign, Statement as S,
             Line)
 
-    from loopy.schedule import get_valid_index_vars
-    from loopy.codegen.bounds import generate_bounds_checks
+    from loopy.codegen.bounds import (
+            generate_bounds_checks,
+            generate_bounds_checks_code,
+            get_valid_check_vars,
+            constraint_to_code)
+
+    if not no_conditional_check:
+        # {{{ see if there are any applicable conditionals
+
+        applicable_constraints = generate_bounds_checks(
+                kernel.domain,
+                get_valid_check_vars(kernel, sched_index, allow_ilp=False),
+                exec_domain.implemented_domain)
+
+        if applicable_constraints:
+            import islpy as isl
+            exec_domain_restriction = isl.Set.universe(kernel.space)
+            for cns in applicable_constraints:
+                exec_domain_restriction = (exec_domain_restriction
+                        .add_constraint(cns))
+
+            exec_domain = exec_domain.intersect(exec_domain_restriction)
+
+            inner = build_loop_nest(cgs, kernel, sched_index, exec_domain,
+                    no_conditional_check=True)
+
+            from loopy.codegen import wrap_in_if
+            return wrap_in_if([
+                constraint_to_code(ccm, cns)
+                for cns in applicable_constraints],
+                inner)
+
+        # }}}
 
     if sched_index >= len(kernel.schedule):
         # {{{ write innermost loop body
@@ -24,9 +55,9 @@ def build_loop_nest(cgs, kernel, sched_index, exec_domain):
 
         # FIXME revert to unroll if actual bounds checks are needed?
 
-        valid_index_vars = get_valid_index_vars(kernel, sched_index)
+        valid_index_vars = get_valid_check_vars(kernel, sched_index, allow_ilp=True)
         bounds_check_lists = [
-                generate_bounds_checks(ccm, kernel.domain,
+                generate_bounds_checks_code(ccm, kernel.domain,
                     valid_index_vars, impl_domain)
                 for assignments, impl_domain in
                     exec_domain]
@@ -64,15 +95,19 @@ def build_loop_nest(cgs, kernel, sched_index, exec_domain):
     if isinstance(sched_item, ScheduledLoop):
         from loopy.codegen.loop_dim import (
                 generate_unroll_or_ilp_code,
-                generate_non_unroll_loop_dim_code)
-        from loopy.kernel import BaseUnrollTag, TAG_ILP
+                generate_parallel_loop_dim_code,
+                generate_sequential_loop_dim_code)
+        from loopy.kernel import (BaseUnrollTag, TAG_ILP,
+                ParallelTagWithAxis)
 
         tag = kernel.iname_to_tag.get(sched_item.iname)
 
         if isinstance(tag, (BaseUnrollTag, TAG_ILP)):
             func = generate_unroll_or_ilp_code
+        elif isinstance(tag, ParallelTagWithAxis):
+            func = generate_parallel_loop_dim_code
         else:
-            func = generate_non_unroll_loop_dim_code
+            func = generate_sequential_loop_dim_code
 
         return func(cgs, kernel, sched_index, exec_domain)
 
@@ -87,7 +122,6 @@ def build_loop_nest(cgs, kernel, sched_index, exec_domain):
                     exec_domain)]
                 +[Line()])
 
-
         for i, (idx_assignments, impl_domain) in \
                 enumerate(exec_domain):
             for lvalue, expr in kernel.instructions:
@@ -96,7 +130,7 @@ def build_loop_nest(cgs, kernel, sched_index, exec_domain):
 
                 wrapped_assign = wrap_in_bounds_checks(
                         ccm, kernel.domain,
-                        get_valid_index_vars(kernel, sched_index),
+                        get_valid_check_vars(kernel, sched_index, allow_ilp=True),
                         impl_domain, assignment)
 
                 cb = []
diff --git a/loopy/codegen/loop_dim.py b/loopy/codegen/loop_dim.py
index d6d9814b4..503cc6831 100644
--- a/loopy/codegen/loop_dim.py
+++ b/loopy/codegen/loop_dim.py
@@ -10,7 +10,98 @@ from loopy.codegen.dispatch import build_loop_nest
 
 
 
-# {{{ generate code for unrolled/ILP loops
+# {{{ conditional-minimizing slab decomposition
+
+def get_slab_decomposition(cgs, kernel, sched_index, exec_domain):
+    from loopy.isl import (cast_constraint_to_space,
+            block_shift_constraint, negate_constraint)
+
+    ccm = cgs.c_code_mapper
+    space = kernel.space
+    iname = kernel.schedule[sched_index].iname
+    tag = kernel.iname_to_tag.get(iname)
+
+    # {{{ attempt slab partition to reduce conditional count
+
+    lb_cns_orig, ub_cns_orig = kernel.get_projected_bounds_constraints(iname)
+    lb_cns_orig = cast_constraint_to_space(lb_cns_orig, space)
+    ub_cns_orig = cast_constraint_to_space(ub_cns_orig, space)
+
+    # jostle the constant in {lb,ub}_cns to see if we can get
+    # fewer conditionals in the bulk middle segment
+
+    class TrialRecord(Record):
+        pass
+
+    if (cgs.try_slab_partition
+            and "outer" in iname):
+        trial_cgs = cgs.copy(try_slab_partition=False)
+        trials = []
+
+        for lower_incr, upper_incr in [ (0,0), (0,-1), ]:
+
+            lb_cns = block_shift_constraint(lb_cns_orig, iname, -lower_incr)
+            ub_cns = block_shift_constraint(ub_cns_orig, iname, -upper_incr)
+
+            bulk_slab = (isl.Set.universe(kernel.space)
+                    .add_constraint(lb_cns)
+                    .add_constraint(ub_cns))
+            bulk_exec_domain = exec_domain.intersect(bulk_slab)
+            inner = build_loop_nest(trial_cgs, kernel, sched_index+1,
+                    bulk_exec_domain)
+
+            trials.append((TrialRecord(
+                lower_incr=lower_incr,
+                upper_incr=upper_incr,
+                bulk_slab=bulk_slab),
+                (inner.num_conditionals,
+                    # when all num_conditionals are equal, choose the
+                    # one with the smallest bounds changes
+                    abs(upper_incr)+abs(lower_incr))))
+
+        from pytools import argmin2
+        chosen = argmin2(trials)
+    else:
+        bulk_slab = (isl.Set.universe(kernel.space)
+                .add_constraint(lb_cns_orig)
+                .add_constraint(ub_cns_orig))
+        chosen = TrialRecord(
+                    lower_incr=0,
+                    upper_incr=0,
+                    bulk_slab=bulk_slab)
+
+    # }}}
+
+    # {{{ build slabs
+
+    slabs = []
+    if chosen.lower_incr:
+        slabs.append(("initial", isl.Set.universe(kernel.space)
+                .add_constraint(lb_cns_orig)
+                .add_constraint(ub_cns_orig)
+                .add_constraint(
+                    negate_constraint(
+                        block_shift_constraint(
+                            lb_cns_orig, iname, -chosen.lower_incr)))))
+
+    slabs.append(("bulk", chosen.bulk_slab))
+
+    if chosen.upper_incr:
+        slabs.append(("final", isl.Set.universe(kernel.space)
+                .add_constraint(ub_cns_orig)
+                .add_constraint(lb_cns_orig)
+                .add_constraint(
+                    negate_constraint(
+                        block_shift_constraint(
+                            ub_cns_orig, iname, -chosen.upper_incr)))))
+
+    # }}}
+
+    return lb_cns_orig, ub_cns_orig, slabs
+
+# }}}
+
+# {{{ unrolled/ILP loops
 
 def generate_unroll_or_ilp_code(cgs, kernel, sched_index, exec_domain):
     from loopy.isl import (
@@ -88,116 +179,70 @@ def generate_unroll_or_ilp_code(cgs, kernel, sched_index, exec_domain):
 
 # }}}
 
-# {{{ generate code for all other loops
+# {{{ parallel loop
 
-def generate_non_unroll_loop_dim_code(cgs, kernel, sched_index, exec_domain):
-    from loopy.isl import (cast_constraint_to_space,
-            block_shift_constraint, negate_constraint, make_slab)
+def generate_parallel_loop_dim_code(cgs, kernel, sched_index, exec_domain):
+    from loopy.isl import make_slab
 
-    from cgen import (Comment, add_comment, make_multiple_ifs)
 
     ccm = cgs.c_code_mapper
     space = kernel.space
     iname = kernel.schedule[sched_index].iname
     tag = kernel.iname_to_tag.get(iname)
 
-    # {{{ attempt slab partition to reduce conditional count
-
-    lb_cns_orig, ub_cns_orig = kernel.get_projected_bounds_constraints(iname)
-    lb_cns_orig = cast_constraint_to_space(lb_cns_orig, space)
-    ub_cns_orig = cast_constraint_to_space(ub_cns_orig, space)
-
-    # jostle the constant in {lb,ub}_cns to see if we can get
-    # fewer conditionals in the bulk middle segment
-
-    class TrialRecord(Record):
-        pass
+    lb_cns_orig, ub_cns_orig, slabs = get_slab_decomposition(
+            cgs, kernel, sched_index, exec_domain)
 
-    if (cgs.try_slab_partition
-            and "outer" in iname):
-        trial_cgs = cgs.copy(try_slab_partition=False)
-        trials = []
+    # For a parallel loop dimension, the global loop bounds are
+    # automatically obeyed--simply because no work items are launched
+    # outside the requested grid.
+    #
+    # For a forced length, this is implemented by an if below.
 
-        for lower_incr, upper_incr in [ (0,0), (0,-1), ]:
-
-            lb_cns = block_shift_constraint(lb_cns_orig, iname, -lower_incr)
-            ub_cns = block_shift_constraint(ub_cns_orig, iname, -upper_incr)
+    if tag.forced_length is None:
+        exec_domain = exec_domain.intersect(
+                isl.Set.universe(kernel.space)
+                .add_constraint(lb_cns_orig)
+                .add_constraint(ub_cns_orig))
+    else:
+        impl_len = tag.forced_length
+        start, _ = kernel.get_projected_bounds(iname)
+        exec_domain = exec_domain.intersect(
+                make_slab(kernel.space, iname, start, start+impl_len))
 
-            bulk_slab = (isl.Set.universe(kernel.space)
-                    .add_constraint(lb_cns)
-                    .add_constraint(ub_cns))
-            bulk_exec_domain = exec_domain.intersect(bulk_slab)
-            inner = build_loop_nest(trial_cgs, kernel, sched_index+1,
-                    bulk_exec_domain)
+    result = []
+    nums_of_conditionals = []
 
-            trials.append((TrialRecord(
-                lower_incr=lower_incr,
-                upper_incr=upper_incr,
-                bulk_slab=bulk_slab),
-                (inner.num_conditionals,
-                    # when all num_conditionals are equal, choose the
-                    # one with the smallest bounds changes
-                    abs(upper_incr)+abs(lower_incr))))
+    from loopy.codegen import add_comment
 
-        from pytools import argmin2
-        chosen = argmin2(trials)
-    else:
-        bulk_slab = (isl.Set.universe(kernel.space)
-                .add_constraint(lb_cns_orig)
-                .add_constraint(ub_cns_orig))
-        chosen = TrialRecord(
-                    lower_incr=0,
-                    upper_incr=0,
-                    bulk_slab=bulk_slab)
+    for slab_name, slab in slabs:
+        cmt = "%s slab for '%s'" % (slab_name, iname)
+        if len(slabs) == 1:
+            cmt = None
 
-    # }}}
+        new_kernel = kernel.copy(
+                domain=kernel.domain.intersect(slab))
+        result.append(
+                add_comment(cmt,
+                    build_loop_nest(cgs, kernel, sched_index+1,
+                        exec_domain)))
 
-    # {{{ build slabs
+    from loopy.codegen import gen_code_block
+    return gen_code_block(result, is_alternatives=True)
 
-    slabs = []
-    if chosen.lower_incr:
-        slabs.append(("initial", isl.Set.universe(kernel.space)
-                .add_constraint(lb_cns_orig)
-                .add_constraint(ub_cns_orig)
-                .add_constraint(
-                    negate_constraint(
-                        block_shift_constraint(
-                            lb_cns_orig, iname, -chosen.lower_incr)))))
+# }}}
 
-    slabs.append(("bulk", chosen.bulk_slab))
+# {{{ sequential loop
 
-    if chosen.upper_incr:
-        slabs.append(("final", isl.Set.universe(kernel.space)
-                .add_constraint(ub_cns_orig)
-                .add_constraint(lb_cns_orig)
-                .add_constraint(
-                    negate_constraint(
-                        block_shift_constraint(
-                            ub_cns_orig, iname, -chosen.upper_incr)))))
+def generate_sequential_loop_dim_code(cgs, kernel, sched_index, exec_domain):
 
-    # }}}
+    ccm = cgs.c_code_mapper
+    space = kernel.space
+    iname = kernel.schedule[sched_index].iname
+    tag = kernel.iname_to_tag.get(iname)
 
-    # {{{ generate code
-
-    from loopy.kernel import AxisParallelTag
-    if isinstance(tag, AxisParallelTag):
-        # For a parallel loop dimension, the global loop bounds are
-        # automatically obeyed--simply because no work items are launched
-        # outside the requested grid.
-        #
-        # For a forced length, this is actually implemented
-        # by an if below.
-
-        if tag.forced_length is None:
-            exec_domain = exec_domain.intersect(
-                    isl.Set.universe(kernel.space)
-                    .add_constraint(lb_cns_orig)
-                    .add_constraint(ub_cns_orig))
-        else:
-            impl_len = tag.forced_length
-            start, _ = kernel.get_projected_bounds(iname)
-            exec_domain = exec_domain.intersect(
-                    make_slab(kernel.space, iname, start, start+impl_len))
+    lb_cns_orig, ub_cns_orig, slabs = get_slab_decomposition(
+            cgs, kernel, sched_index, exec_domain)
 
     result = []
     nums_of_conditionals = []
@@ -211,49 +256,16 @@ def generate_non_unroll_loop_dim_code(cgs, kernel, sched_index, exec_domain):
         inner = build_loop_nest(cgs, kernel, sched_index+1,
                 new_exec_domain)
 
-        if tag is None:
-            from loopy.codegen.bounds import wrap_in_for_from_constraints
-
-            # regular loop
-            if cmt is not None:
-                result.append(Comment(cmt))
-            result.append(
-                    wrap_in_for_from_constraints(ccm, iname, slab, inner))
-        else:
-            # parallel loop
-            par_impl_domain = exec_domain.get_the_one_domain()
-
-            from loopy.schedule import get_valid_index_vars
-            from loopy.codegen.bounds import generate_bounds_checks
-
-            nums_of_conditionals.append(inner.num_conditionals)
-            constraint_codelets = generate_bounds_checks(ccm,
-                    slab, get_valid_index_vars(kernel, sched_index+1),
-                    par_impl_domain)
-            result.append(
-                    ("\n&& ".join(constraint_codelets),
-                        add_comment(cmt, inner.ast)))
-
-    if tag is None:
-        # regular or unrolled loop
-        return gen_code_block(result)
-
-    elif isinstance(tag, AxisParallelTag):
-        # parallel loop
-        if tag.forced_length is None:
-            base = "last"
-        else:
-            base = None
+        from loopy.codegen.bounds import wrap_in_for_from_constraints
 
-        from loopy.codegen import GeneratedCode
-        return GeneratedCode(
-                ast=make_multiple_ifs(result, base=base),
-                num_conditionals=min(nums_of_conditionals))
+        # regular loop
+        if cmt is not None:
+            from cgen import Comment
+            result.append(Comment(cmt))
+        result.append(
+                wrap_in_for_from_constraints(ccm, iname, slab, inner))
 
-    else:
-        assert False, "we aren't supposed to get here"
-
-    # }}}
+    return gen_code_block(result)
 
 # }}}
 
diff --git a/loopy/codegen/prefetch.py b/loopy/codegen/prefetch.py
index 53bb1766f..d2b2dcdca 100644
--- a/loopy/codegen/prefetch.py
+++ b/loopy/codegen/prefetch.py
@@ -325,9 +325,10 @@ def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
 
     # {{{ generate fetch code
 
-    from loopy.schedule import get_valid_index_vars
-    valid_index_vars = get_valid_index_vars(kernel, sched_index,
-            exclude_tags=(TAG_WORK_ITEM_IDX,))
+    from loopy.codegen.bounds import get_valid_check_vars
+    valid_index_vars = get_valid_check_vars(kernel, sched_index,
+            allow_ilp=True,
+            exclude_tag_classes=(TAG_WORK_ITEM_IDX,))
 
     from loopy.symbolic import LoopyCCodeMapper
     flnd = FetchLoopNestData(prefetch=pf,
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 277d184f3..21bf43a44 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -92,7 +92,7 @@ class IndexTag(Record):
 class ParallelTag(IndexTag):
     pass
 
-class AxisParallelTag(ParallelTag):
+class ParallelTagWithAxis(ParallelTag):
     __slots__ = ["axis", "forced_length"]
 
     def __init__(self, axis, forced_length=None):
@@ -112,10 +112,10 @@ class AxisParallelTag(ParallelTag):
             return "%s(%d)" % (
                     self.print_name, self.axis)
 
-class TAG_GROUP_IDX(AxisParallelTag):
+class TAG_GROUP_IDX(ParallelTagWithAxis):
     print_name = "GROUP_IDX"
 
-class TAG_WORK_ITEM_IDX(AxisParallelTag):
+class TAG_WORK_ITEM_IDX(ParallelTagWithAxis):
     print_name = "WORK_ITEM_IDX"
 
 class TAG_ILP(ParallelTag):
@@ -355,6 +355,7 @@ class LoopKernel(Record):
         return False
 
     def _subst_prefetch(self, old_var, new_expr):
+        # FIXME delete me
         from pymbolic.mapper.substitutor import substitute
         subst_map = {old_var: new_expr}
 
@@ -373,7 +374,8 @@ class LoopKernel(Record):
     def substitute(self, old_var, new_expr):
         copy = self.copy(instructions=self._subst_insns(old_var, new_expr))
         if self.prefetch:
-            copy.prefetch = self._subst_prefetch(old_var, new_expr)
+            raise RuntimeError("cannot substitute-prefetches already generated")
+            #copy.prefetch = self._subst_prefetch(old_var, new_expr)
 
         if self.schedule is not None:
             raise RuntimeError("cannot substitute-schedule already generated")
diff --git a/loopy/schedule.py b/loopy/schedule.py
index 1966b390b..bda21a58a 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -5,8 +5,6 @@ from pytools import Record
 
 
 
-# {{{ loop scheduling
-
 # {{{ schedule items
 
 class ScheduledLoop(Record):
@@ -144,21 +142,6 @@ def generate_loop_schedules(kernel, hints=[]):
         if all_inames_scheduled and all_pf_scheduled and output_scheduled:
             yield kernel
 
-# }}}
-
-
-
-
-def get_valid_index_vars(kernel, sched_index, exclude_tags=()):
-    """
-    :param exclude_tags: a tuple of tag classes to exclude
-    """
-    return [
-            sched_item.iname
-            for sched_item in kernel.schedule[:sched_index]
-            if isinstance(sched_item, ScheduledLoop)
-            if not isinstance(kernel.iname_to_tag.get(sched_item.iname), exclude_tags)]
-
 
 
 
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 5d46deae9..2618745cd 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -205,7 +205,7 @@ def test_image_matrix_mul_ilp(ctx_factory):
                 lp.ImageArg("b", dtype, 2),
                 lp.ArrayArg("c", dtype, shape=(n, n), order=order),
                 ],
-            name="matmul", preamble=DEBUG_PREAMBLE)
+            name="matmul")
 
     ilp = 4
     knl = lp.split_dimension(knl, "i", 2, outer_tag="g.0", inner_tag="l.1")
@@ -214,8 +214,8 @@ def test_image_matrix_mul_ilp(ctx_factory):
     knl = lp.split_dimension(knl, "j_inner", j_inner_split, outer_tag="ilp", inner_tag="l.0")
     knl = lp.split_dimension(knl, "k", 2)
     # conflict-free
-    #knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
-    #knl = lp.add_prefetch(knl, 'b', ["j_inner_outer", "j_inner_inner", "k_inner"])
+    knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
+    knl = lp.add_prefetch(knl, 'b', ["j_inner_outer", "j_inner_inner", "k_inner"])
     assert knl.get_problems()[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
@@ -397,8 +397,7 @@ def test_dg_matrix_mul(ctx_factory):
 
         return evt
 
-    lp.drive_timing_run(kernel_gen, queue, launcher, num_flds*dim*2*(Np**2)*K,
-            edit_code=True)
+    lp.drive_timing_run(kernel_gen, queue, launcher, num_flds*dim*2*(Np**2)*K)
 
 
 
-- 
GitLab