From 6ae4c97faba6c8dbc04fcbad6d21c28bba43e251 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 14 Aug 2011 00:08:15 +0200
Subject: [PATCH] Fetch dimension merging. Code readability. No spurious
 barriers.

---
 loopy/__init__.py         |  30 +++--
 loopy/codegen/__init__.py |   7 +-
 loopy/codegen/prefetch.py | 262 ++++++++++++++++++++++++++------------
 loopy/prefetch.py         |  40 ++++--
 loopy/schedule.py         |   4 +-
 loopy/symbolic.py         |   2 +-
 test/test_matmul.py       |   7 +-
 7 files changed, 241 insertions(+), 111 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index 6cba9ccd3..64f426dc9 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -11,30 +11,25 @@ register_mpz_with_pymbolic()
 
 
 
-# TODO: Reuse of previously split dimensions for prefetch
-#   (Or general merging)
-
 # TODO: Try, fix reg. prefetch (DG example) / CSEs
+#   ILP and reg. prefetch (might) interact!
 # TODO: Custom reductions per red. axis
 # TODO: Functions
 # TODO: Common subexpressions
 # TODO: Parse ops from string
-# FIXME: support non-reductive dimensions
+# FIXME: support non-reductive dimensions (what did I mean here?)
 # FIXME: write names should be assigned during scheduling
 
-# TODO: Don't emit spurious barriers (no for scheduled before)
-# TODO: Make code more readable
-
 # TODO: Divisibility
 # TODO: Try different kernels
 # TODO:   - Tricky: Convolution, FD
 # TODO: Try, fix indirect addressing
-# TODO: User controllable switch for slab opt
+# TODO: More user control for slab opt
 # TODO: Separate all-bulk from non-bulk kernels. (maybe?) (#ifdef?)
 
-# TODO: implement efficient div_ceil? (as opposed to floor_div)
+# TODO: implement efficient ceil_div? (as opposed to floor_div)
 # TODO: why are corner cases inefficient?
-# TODO: Use gists
+# TODO: Use gists (why do disjoint sets arise?)
 # TODO: Imitate codegen bulk slab handling in bulk slab trials
 
 
@@ -78,12 +73,13 @@ def get_input_access_descriptors(kernel):
 
     return result
 
-def add_prefetch(kernel, input_access_descr, tags_or_inames, loc_fetch_axes={}):
+def add_prefetch(kernel, input_access_descr, fetch_dims, loc_fetch_axes={}):
     """
     :arg input_access_descr: see :func:`get_input_access_descriptors`.
         May also be the name of the variable if there is only one
         reference to that variable.
-    :arg tags_or_inames: loop dimensions that are used to carry out the prefetch
+    :arg fetch_dims: loop dimensions indexing the input variable on which
+        the prefetch is to be carried out.
     """
 
     if isinstance(input_access_descr, str):
@@ -96,7 +92,13 @@ def add_prefetch(kernel, input_access_descr, tags_or_inames, loc_fetch_axes={}):
 
         input_access_descr, = var_iads
 
-    inames = [kernel.tag_or_iname_to_iname(s) for s in tags_or_inames]
+    def parse_fetch_dim(iname):
+        if isinstance(iname, str):
+            iname = (iname,)
+
+        return tuple(kernel.tag_or_iname_to_iname(s) for s in iname)
+
+    fetch_dims = [parse_fetch_dim(fd) for fd in fetch_dims]
     ivec, iexpr = input_access_descr
 
     new_prefetch = getattr(kernel, "prefetch", {}).copy()
@@ -109,7 +111,7 @@ def add_prefetch(kernel, input_access_descr, tags_or_inames, loc_fetch_axes={}):
             kernel=kernel,
             input_vector=ivec,
             index_expr=iexpr,
-            inames=inames,
+            fetch_dims=fetch_dims,
             loc_fetch_axes={})
 
     return kernel.copy(prefetch=new_prefetch)
diff --git a/loopy/codegen/__init__.py b/loopy/codegen/__init__.py
index d0e4cf4a3..e8f10a1a8 100644
--- a/loopy/codegen/__init__.py
+++ b/loopy/codegen/__init__.py
@@ -15,7 +15,7 @@ class GeneratedCode(Record):
     """
     __slots__ = ["ast", "num_conditionals"]
 
-def gen_code_block(elements, is_alternatives=False):
+def gen_code_block(elements, is_alternatives=False, denest=False):
     """
     :param is_alternatives: a :class:`bool` indicating that
         only one of the *elements* will effectively be executed.
@@ -28,7 +28,10 @@ def gen_code_block(elements, is_alternatives=False):
     for el in elements:
         if isinstance(el, GeneratedCode):
             conditional_counts.append(el.num_conditionals)
-            block_els.append(el.ast)
+            if isinstance(el.ast, Block) and denest:
+                block_els.extend(el.ast.contents)
+            else:
+                block_els.append(el.ast)
         elif isinstance(el, Generable):
             block_els.append(el)
         else:
diff --git a/loopy/codegen/prefetch.py b/loopy/codegen/prefetch.py
index 946f23cc9..e7d880cf2 100644
--- a/loopy/codegen/prefetch.py
+++ b/loopy/codegen/prefetch.py
@@ -4,6 +4,9 @@ from pytools import Record
 import pyopencl as cl
 import pyopencl.characterize as cl_char
 from loopy.codegen import wrap_in, gen_code_block
+import islpy as isl
+from islpy import dim_type
+import numpy as np
 
 
 
@@ -22,8 +25,8 @@ def preprocess_prefetch(kernel):
         all_pf_nbytes = [opf.nbytes for opf in all_pf_list]
         other_pf_sizes = sum(all_pf_nbytes[:i_pf]+all_pf_nbytes[i_pf+1:])
 
-        shape = [stop-start for start, stop in pf.dim_bounds]
-        dim_storage_lengths = shape[:]
+        dim_storage_lengths = [stop-start for start, stop in
+                [pf.dim_bounds_by_iname[iname] for iname in pf.all_inames()]]
 
         # sizes of all dims except the last one, which we may change
         # below to avoid bank conflicts
@@ -47,7 +50,7 @@ def preprocess_prefetch(kernel):
                 test_dsl[-1] = test_dsl[-1] + increment
                 new_mult, why_not = cl_char.why_not_local_access_conflict_free(
                         kernel.device, pf.itemsize,
-                        shape, test_dsl)
+                        pf.dim_lengths(), test_dsl)
 
                 # will choose smallest increment 'automatically'
                 if new_mult < min_mult:
@@ -87,7 +90,7 @@ def preprocess_prefetch(kernel):
 class FetchLoopNestData(Record):
     pass
 
-def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
+def make_fetch_loop_nest(flnd, fetch_dim_idx, pf_dim_exprs, iname_subst_map,
         implemented_domain):
     pf = flnd.prefetch
     ccm = flnd.c_code_mapper
@@ -98,7 +101,7 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
     from cgen import Assign, For, If
 
     from pymbolic.mapper.substitutor import substitute
-    if pf_iname_idx >= len(pf.inames):
+    if fetch_dim_idx >= len(pf.fetch_dims):
         # done, return
         from pymbolic.primitives import Variable, Subscript
 
@@ -109,11 +112,11 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
                 no_pf_ccm(
                     Subscript(
                         Variable(pf.input_vector),
-                        substitute(pf.index_expr, pf_idx_subst_map)),
+                        substitute(pf.index_expr, iname_subst_map)),
                     PREC_NONE))
 
         def my_ccm(expr):
-            return ccm(substitute(expr, pf_idx_subst_map))
+            return ccm(substitute(expr, iname_subst_map))
 
         from pymbolic.mapper.dependency import DependencyMapper
         check_vars = [v.name for v in DependencyMapper()(pf.index_expr)]
@@ -122,16 +125,73 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
         return wrap_in_bounds_checks(my_ccm, pf.kernel.domain,
                 check_vars, implemented_domain, result)
 
-    pf_iname = pf.inames[pf_iname_idx]
-    realiz_inames = flnd.realization_inames[pf_iname_idx]
+    fetch_inames = pf.fetch_dims[fetch_dim_idx]
+    realiz_inames = flnd.realization_inames[fetch_dim_idx]
 
-    start_index, stop_index = pf.dim_bounds_by_iname[pf_iname]
+    fetch_iname_lengths = [stop-start
+            for start, stop in 
+            [pf.dim_bounds_by_iname[iname] for iname in fetch_inames]]
 
-    dim_length = stop_index-start_index
+    from pytools import product
+    dim_length = product(fetch_iname_lengths)
+
+    idx_var_name = "loopy_prefetch_dim_idx_%d" % fetch_dim_idx
+    idx_var = var(idx_var_name)
 
     if realiz_inames is not None:
         # {{{ parallel fetch
 
+        # {{{ find strides per fetch iname
+
+        fetch_iname_strides = [1]
+        for fil in fetch_iname_lengths[:0:-1]:
+            fetch_iname_strides.insert(0,
+                    fetch_iname_strides[0]*fil)
+
+        # }}}
+
+        idx_var_expr_from_inames = sum(stride*var(iname)
+                for iname, stride in zip(fetch_inames, fetch_iname_strides))
+
+        # {{{ find expressions for each iname from idx_var
+
+        pf_dim_exprs = pf_dim_exprs[:]
+        iname_subst_map = iname_subst_map.copy()
+
+        for i, iname in enumerate(fetch_inames):
+            iname_lower, iname_upper = pf.dim_bounds_by_iname[iname]
+            iname_len = iname_upper-iname_lower
+            iname_val_base = (idx_var // fetch_iname_strides[i])
+            if i != 0:
+                # the outermost iname is the 'largest', no need to
+                # 'modulo away' any larger ones
+                iname_val_base = iname_val_base % iname_len
+
+            pf_dim_exprs.append(iname_val_base)
+            iname_subst_map[iname] = iname_val_base + iname_lower
+
+        # }}}
+
+        # {{{ build an implemented domain with an extra index variable
+
+        from loopy.symbolic import eq_constraint_from_expr
+        idx_var_dim_idx = implemented_domain.get_dim().size(dim_type.set)
+        impl_domain_with_index_var = implemented_domain.add_dims(dim_type.set, 1)
+        impl_domain_with_index_var = (
+                impl_domain_with_index_var
+                .set_dim_name(dim_type.set, idx_var_dim_idx, idx_var_name))
+        aug_space = impl_domain_with_index_var.get_dim()
+        impl_domain_with_index_var = (
+                impl_domain_with_index_var
+                .intersect(
+                    isl.Set.universe(aug_space)
+                    .add_constraint(
+                        eq_constraint_from_expr(
+                            aug_space,
+                            idx_var_expr_from_inames - idx_var))))
+
+        # }}}
+
         realiz_bounds = [
                 flnd.kernel.get_bounds(rn, (rn,), allow_parameters=False)
                 for rn in realiz_inames]
@@ -146,35 +206,43 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
 
         cur_index = 0
 
-        while start_index+cur_index < stop_index:
-            pf_dim_expr = 0
+        while cur_index < dim_length:
+            pf_idx_expr = 0
             for realiz_iname, length in zip(realiz_inames, realiz_lengths):
                 tag = flnd.kernel.iname_to_tag[realiz_iname]
                 from loopy.kernel import TAG_WORK_ITEM_IDX
                 assert isinstance(tag, TAG_WORK_ITEM_IDX)
 
-                pf_dim_expr = (pf_dim_expr*length
+                pf_idx_expr = (pf_idx_expr*length
                         + var("(int) get_local_id(%d)" % tag.axis))
 
-            from loopy.isl import make_slab
-            loop_slab = make_slab(pf.kernel.space, pf_iname,
-                    start_index+cur_index,
-                    min(stop_index, start_index+cur_index+total_realiz_size))
-            new_impl_domain = implemented_domain.intersect(loop_slab)
-
-            pf_dim_expr += cur_index
+            pf_idx_expr += cur_index
 
-            pf_idx_subst_map = pf_idx_subst_map.copy()
-            pf_idx_subst_map[pf_iname] = pf_dim_expr + start_index
-            inner = make_fetch_loop_nest(flnd, pf_iname_idx+1,
-                    pf_dim_exprs+[pf_dim_expr], pf_idx_subst_map,
+            from loopy.isl import make_slab
+            new_impl_domain = (
+                    impl_domain_with_index_var
+                    .intersect(
+                        make_slab(
+                            impl_domain_with_index_var.get_dim(), idx_var_name,
+                            cur_index,
+                            min(dim_length, cur_index+total_realiz_size)))
+                    .project_out(dim_type.set, idx_var_dim_idx, 1))
+
+            inner = make_fetch_loop_nest(flnd, fetch_dim_idx+1,
+                    pf_dim_exprs, iname_subst_map,
                     new_impl_domain)
 
             if cur_index+total_realiz_size > dim_length:
                 inner = wrap_in(If,
-                        "%s < %s" % (ccm(pf_dim_expr), stop_index),
+                        "%s < %s" % (idx_var_name, dim_length),
                         inner)
 
+            from cgen import Initializer, Const, POD
+            inner = gen_code_block([
+                Initializer(Const(POD(np.int32, idx_var_name)),
+                    ccm(pf_idx_expr)),
+                inner], denest=True)
+
             result.append(inner)
 
             cur_index += total_realiz_size
@@ -185,28 +253,28 @@ def make_fetch_loop_nest(flnd, pf_iname_idx, pf_dim_exprs, pf_idx_subst_map,
     else:
         # {{{ sequential fetch
 
-        pf_dim_var = "prefetch_dim_idx_%d" % pf_iname_idx
-        pf_dim_expr = var(pf_dim_var)
+        if len(fetch_inames) > 1:
+            raise NotImplementedError("merged sequential fetches are not supported")
+        pf_iname, = fetch_inames
 
         lb_cns, ub_cns = pf.get_dim_bounds_constraints_by_iname(pf_iname)
 
-        import islpy as isl
         from loopy.isl import cast_constraint_to_space
         loop_slab = (isl.Set.universe(flnd.kernel.space)
                 .add_constraints([cast_constraint_to_space(cns, kernel.space)
                     for cns in [lb_cns, ub_cns]]))
         new_impl_domain = implemented_domain.intersect(loop_slab)
 
-        pf_idx_subst_map = pf_idx_subst_map.copy()
-        pf_idx_subst_map[pf_iname] = pf_dim_expr + start_index
-        inner = make_fetch_loop_nest(flnd, pf_iname_idx+1,
-                pf_dim_exprs+[pf_dim_expr], pf_idx_subst_map,
+        iname_subst_map = iname_subst_map.copy()
+        iname_subst_map[pf_iname] = idx_var + pf.dim_bounds_by_iname[pf_iname][0]
+        inner = make_fetch_loop_nest(flnd, fetch_dim_idx+1,
+                pf_dim_exprs+[idx_var], iname_subst_map,
                 new_impl_domain)
 
         return wrap_in(For,
-                "int %s = 0" % pf_dim_var,
-                "%s < %s" % (pf_dim_var, ccm(dim_length)),
-                "++%s" % pf_dim_var,
+                "int %s = 0" % idx_var_name,
+                "%s < %s" % (idx_var_name, ccm(dim_length)),
+                "++%s" % idx_var_name,
                 inner)
 
         # }}}
@@ -219,17 +287,6 @@ def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
 
     ccm = cgs.c_code_mapper
 
-    # find surrounding schedule items
-    if sched_index-1 >= 0:
-        next_outer_sched_item = kernel.schedule[sched_index-1]
-    else:
-        next_outer_sched_item = None
-
-    if sched_index+1 < len(kernel.schedule):
-        next_inner_sched_item = kernel.schedule[sched_index+1]
-    else:
-        next_inner_sched_item = None
-
     scheduled_pf = kernel.schedule[sched_index]
     pf = kernel.prefetch[
             scheduled_pf.input_vector, scheduled_pf.index_expr]
@@ -243,7 +300,7 @@ def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
     # realization_dims is a list of lists of inames, to represent when two dims jointly
     # make up one fetch axis
 
-    realization_inames = [None] * len(pf.inames)
+    realization_inames = [None] * len(pf.fetch_dims)
 
     # {{{ first, fix the user-specified fetch dims
 
@@ -290,26 +347,28 @@ def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
             for arg in kernel.args
             if isinstance(arg, ScalarArg))
 
-    def stride_key(iname):
-        iname_stride = iname_to_stride[iname]
+    def stride_key(fetch_dim_idx):
+        fetch_dim = pf.fetch_dims[fetch_dim_idx]
 
         from pymbolic import evaluate
-        key = evaluate(iname_stride, approximate_arg_values)
+        key = min(
+                evaluate(iname_to_stride[iname], approximate_arg_values)
+                for iname in fetch_dim)
         assert isinstance(key, int)
         return key
 
-    pf_iname_strides = sorted((iname
-        for dim_idx, iname in enumerate(pf.inames)
+    pf_fetch_dim_strides = sorted((dim_idx
+        for dim_idx in range(len(pf.fetch_dims))
         if realization_inames[dim_idx] is None),
         key=stride_key)
 
-    while knl_work_item_inames and pf_iname_strides:
+    while knl_work_item_inames and pf_fetch_dim_strides:
         # grab least-stride prefetch dim
-        least_stride_pf_iname = pf_iname_strides.pop(0)
+        least_stride_pf_fetch_dim_idx = pf_fetch_dim_strides.pop(0)
 
         # FIXME: It might be good to join multiple things together here
         # for size reasons
-        realization_inames[pf.inames.index(least_stride_pf_iname)] \
+        realization_inames[least_stride_pf_fetch_dim_idx] \
                 = [knl_work_item_inames.pop(0)]
 
     if knl_work_item_inames:
@@ -343,38 +402,83 @@ def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
 
     # }}}
 
-    new_block = [
-            Comment("prefetch %s[%s] using %s" % (
-                pf.input_vector,
-                ", ".join(pf.inames),
-                ", ".join(
-                        (" x ".join("%s(%s)" % (realiz_iname, kernel.iname_to_tag[realiz_iname])
-                        for realiz_iname in realiz_inames)
-                        if realiz_inames is not None else "loop")
-                        for realiz_inames in realization_inames))),
-            Line(),
-            ]
+    new_block = []
+
+    # {{{ generate comments explaining dimension mapping
+
+    new_block.append(Comment("prefetch %s -- using dimension mapping:" % pf.input_vector))
+    for iaxis, (fetch_dim, realiz_inames) in enumerate(zip(pf.fetch_dims, realization_inames)):
+        new_block.append(Comment("  fetch axis %d:" % iaxis))
+        for iname in fetch_dim:
+            iname_lwr, iname_upr = pf.dim_bounds_by_iname[iname]
+            new_block.append(Comment("      %s [%d..%d)" % (iname, iname_lwr, iname_upr)))
+        new_block.append(Comment("    using:"))
+        for realiz_iname in realiz_inames:
+
+            if realiz_iname is None:
+                new_block.append(Comment("      loop"))
+            else:
+                rd_iname_descr = "loop"
+                iname_lwr, iname_upr, iname_eq = flnd.kernel.get_bounds(realiz_iname, (realiz_iname,),
+                        allow_parameters=False)
+                assert not iname_eq
+
+                new_block.append(Comment("      %s (%s) [%s..%s)"
+                    % (realiz_iname, kernel.iname_to_tag[realiz_iname],
+                        iname_lwr, iname_upr)))
+
+    new_block.append(Line())
+
+    # }}}
+
+    # {{{ omit head sync primitive if possible
 
-    # omit head sync primitive if we just came out of a prefetch
+    head_sync_unneeded_because = None
 
     from loopy.prefetch import LocalMemoryPrefetch
-    if not isinstance(next_outer_sched_item, LocalMemoryPrefetch):
+    if (sched_index-1 >= 0 
+            and isinstance(kernel.schedule[sched_index-1], LocalMemoryPrefetch)):
+        head_sync_unneeded_because = "next outer schedule item is a prefetch"
+
+    from pytools import all
+    from loopy.kernel import ParallelTag
+    from loopy.schedule import ScheduledLoop
+    outer_tags = [
+            kernel.iname_to_tag.get(sched_item.iname)
+            for sched_item in kernel.schedule[:sched_index]
+            if isinstance(sched_item, ScheduledLoop)]
+
+    if not [tag
+            for tag in outer_tags
+            if not isinstance(tag, ParallelTag)]:
+        head_sync_unneeded_because = "no sequential axes nested around fetch"
+
+    # generate (no) head sync code
+    if head_sync_unneeded_because is None:
         new_block.append(S("barrier(CLK_LOCAL_MEM_FENCE)"))
     else:
-        new_block.append(Comment("next outer schedule item is a prefetch: "
-            "no sync needed"))
+        new_block.append(Comment("no sync needed: " + head_sync_unneeded_because))
+        new_block.append(Line())
+
+    # }}}
+
+    new_block.append(fetch_block)
 
-    new_block.extend([
-        fetch_block,
-        ])
+    # {{{ omit tail sync primitive if possible
 
-    # omit tail sync primitive if we're headed into another prefetch
-    if not isinstance(next_inner_sched_item, LocalMemoryPrefetch):
+    tail_sync_unneeded_because = None
+
+    if (sched_index+1 < len(kernel.schedule)
+            and isinstance(kernel.schedule[sched_index+1], LocalMemoryPrefetch)):
+        tail_sync_unneeded_because = "next inner schedule item is a prefetch"
+
+    if tail_sync_unneeded_because is None:
         new_block.append(S("barrier(CLK_LOCAL_MEM_FENCE)"))
     else:
         new_block.append(Line())
-        new_block.append(Comment("next inner schedule item is a prefetch: "
-            "no sync needed"))
+        new_block.append(Comment("no sync needed: " + tail_sync_unneeded_because))
+
+    # }}}
 
     from loopy.codegen.dispatch import build_loop_nest
     new_block.extend([Line(),
@@ -383,3 +487,5 @@ def generate_prefetch_code(cgs, kernel, sched_index, exec_domain):
     return gen_code_block(new_block)
 
 # }}}
+
+# vim: foldmethod=marker
diff --git a/loopy/prefetch.py b/loopy/prefetch.py
index 4f5b436c0..49330236c 100644
--- a/loopy/prefetch.py
+++ b/loopy/prefetch.py
@@ -72,8 +72,10 @@ class LocalMemoryPrefetch(Record):
     :ivar input_vector: A string indicating the input vector variable name.
     :ivar index_expr: An expression identifying the access which this prefetch
       serves.
-    :ivar inames: A sequence of inames (i.e. loop dimensions) identifying which
-        part of the input vector, given the index_expr, should be prefetched.
+    :ivar fetch_dims: A sequence of tuples of inames (i.e. loop dimensions)
+        identifying which part of the input vector, given the index_expr, should
+        be prefetched. Non-length-1 tuples indicate that these indices should
+        share a dimension in the prefetch array.
     :ivar loc_fetch_axes: dictionary from integers 0..len(inames) to lists of
       local index axes which should be used to realize that dimension of the
       prefetch. The last dimension in this list is used as the fastest-changing
@@ -86,11 +88,21 @@ class LocalMemoryPrefetch(Record):
     The latter two values are only assigned during code generation.
     """
 
+    @memoize_method
+    def all_inames(self):
+        """Order matters as this will be the order of indices into the
+        prefetch array.
+        """
+        return [
+                iname
+                for fetch_dim in self.fetch_dims
+                for iname in fetch_dim]
+
     @property
     @memoize_method
     def domain(self):
         return (self.kernel.domain
-                .project_out_except(self.inames, [dim_type.set])
+                .project_out_except(self.all_inames(), [dim_type.set])
                 .compute_divs()
                 .remove_divs_of_dim_type(dim_type.set))
 
@@ -125,7 +137,7 @@ class LocalMemoryPrefetch(Record):
     def dim_bounds_by_iname(self):
         from loopy.codegen.bounds import solve_constraint_for_bound
         result = {}
-        for iname in self.inames:
+        for iname in self.all_inames():
             lower, upper = self.get_dim_bounds_constraints_by_iname(iname)
 
             lower_kind, lower_bound = solve_constraint_for_bound(lower, iname)
@@ -141,28 +153,32 @@ class LocalMemoryPrefetch(Record):
 
         return result
 
-    @property
-    @memoize_method
-    def dim_bounds(self):
-        dbbi = self.dim_bounds_by_iname
-        return [dbbi[iname] for iname in self.inames]
-
     @property
     def itemsize(self):
         return self.kernel.arg_dict[self.input_vector].dtype.itemsize
 
+    def dim_lengths(self):
+        result = []
+        for fetch_dim in self.fetch_dims:
+            fd_result = 1
+            for iname in fetch_dim:
+                start, stop = self.dim_bounds_by_iname[iname]
+                fd_result *= stop-start
+            result.append(fd_result)
+        return result
+
     @property
     @memoize_method
     def nbytes(self):
         from pytools import product
-        return self.itemsize * product(upper-lower for lower, upper in self.dim_bounds)
+        return self.itemsize * product(self.dim_lengths())
 
     @memoize_method
     def free_variables(self):
         from pymbolic.mapper.dependency import DependencyMapper
         return set(var.name
                 for var in DependencyMapper()(self.index_expr)
-                ) - set(self.inames) - self.kernel.scalar_args()
+                ) - set(self.all_inames()) - self.kernel.scalar_args()
 
     def hash(self):
         return (hash(type(self)) ^ hash(self.input_vector)
diff --git a/loopy/schedule.py b/loopy/schedule.py
index bda21a58a..a400f077f 100644
--- a/loopy/schedule.py
+++ b/loopy/schedule.py
@@ -51,7 +51,7 @@ def generate_loop_schedules(kernel, hints=[]):
         # a prefetch variable already scheduled, but not borrowable?
         # (only work item index variables are borrowable)
 
-        if set(pf.inames) & (scheduled_inames - locally_parallel_inames):
+        if set(pf.all_inames()) & (scheduled_inames - locally_parallel_inames):
             # dead end: we won't be able to schedule this prefetch
             # in this branch. at least one of its loop dimensions
             # was already scheduled, and that dimension is not
@@ -104,7 +104,7 @@ def generate_loop_schedules(kernel, hints=[]):
     unsched_prefetch_axes = set(iname
             for pf in kernel.prefetch.itervalues()
             if pf not in prev_schedule
-            for iname in pf.inames
+            for iname in pf.all_inames()
             if not isinstance(kernel.iname_to_tag.get(iname), ParallelTag))
     schedulable -= unsched_prefetch_axes
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 3056d8d03..2e2587f1d 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -148,7 +148,7 @@ class LoopyCCodeMapper(CCodeMapper):
                         "[%s - %s]" % (iname, self.rec(
                             pf.dim_bounds_by_iname[iname][0],
                             PREC_SUM))
-                        for iname in pf.inames)
+                        for iname in pf.all_inames())
 
         if isinstance(expr.aggregate, Variable):
             arg = self.kernel.arg_dict[expr.aggregate.name]
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 5fe5faaa5..63d82d7ff 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -44,8 +44,11 @@ DEBUG_PREAMBLE = r"""
 
 
 def check_error(refsol, sol):
+    if not DO_CHECK:
+        return
+
     rel_err = la.norm(refsol-sol, "fro")/la.norm(refsol, "fro")
-    if DO_CHECK and rel_err > 1e-5:
+    if rel_err > 1e-5 or np.isinf(rel_err) or np.isnan(rel_err):
         if 1:
             import matplotlib.pyplot as pt
             pt.imshow(refsol-sol)
@@ -215,7 +218,7 @@ def test_image_matrix_mul_ilp(ctx_factory):
     knl = lp.split_dimension(knl, "k", 2)
     # conflict-free
     knl = lp.add_prefetch(knl, 'a', ["i_inner", "k_inner"])
-    knl = lp.add_prefetch(knl, 'b', ["j_inner_outer", "j_inner_inner", "k_inner"])
+    knl = lp.add_prefetch(knl, 'b', [("j_inner_outer", "j_inner_inner"), "k_inner"])
     assert knl.get_problems({})[0] <= 2
 
     kernel_gen = (lp.insert_register_prefetches(knl)
-- 
GitLab