diff --git a/MEMO b/MEMO
index 95d8f452775438126ae2eb75f008cafdd13c2745..50a45695db86a6f8bbddc64491bd8b207f281a95 100644
--- a/MEMO
+++ b/MEMO
@@ -52,16 +52,11 @@ To-do
 
 - reg rolling
 
-- nbody GPU
-  -> pending better prefetch spec
-  - Prefetch by sample access
-  - Exclude by precompute name
-
 - Expose iname-duplicate-and-rename as a primitive.
 
 - add_prefetch gets a flag to separate out each access
 
-- Allow parameters to be varying during run-time varying, substituting values
+- Allow parameters to be varying during run-time, substituting values
   that depend on other inames?
 
 - Fix all tests
@@ -112,6 +107,10 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- nbody GPU
+  -> pending better prefetch spec
+  - Prefetch by sample access
+
 - How is intra-instruction ordering of ILP loops going to be determined?
   (taking into account that it could vary even per-instruction?)
 
diff --git a/doc/reference.rst b/doc/reference.rst
index 0c5128490393c0d961fbb7ac3440922f5b5de5bd..3202fe000b5c3b5c97290fc0a005f0fbf37da7c6 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -131,7 +131,7 @@ Dealing with Substitution Rules
 
 .. autofunction:: extract_subst
 
-.. autofunction:: apply_subst
+.. autofunction:: expand_subst
 
 Precomputation and Prefetching
 ------------------------------
diff --git a/loopy/__init__.py b/loopy/__init__.py
index e1745099d236a0a4fd94b04d33b973dda3acc731..0043a21b1ba9082b6b076800580f27eb651e1e86 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -23,7 +23,7 @@ class LoopyAdvisory(UserWarning):
 from loopy.kernel import ScalarArg, ArrayArg, ConstantArrayArg, ImageArg
 
 from loopy.kernel import AutoFitLocalIndexTag, get_dot_dependency_graph, LoopKernel
-from loopy.subst import extract_subst, apply_subst
+from loopy.subst import extract_subst, expand_subst
 from loopy.cse import precompute
 from loopy.preprocess import preprocess_kernel, realize_reduction
 from loopy.schedule import generate_loop_schedules
@@ -33,12 +33,13 @@ from loopy.check import check_kernels
 
 __all__ = ["ScalarArg", "ArrayArg", "ConstantArrayArg", "ImageArg", "LoopKernel",
         "get_dot_dependency_graph",
-        "preprocess_kernel", "generate_loop_schedules",
+        "preprocess_kernel", "realize_reduction",
+        "generate_loop_schedules",
         "generate_code",
         "CompiledKernel", "drive_timing_run", "auto_test_vs_ref", "check_kernels",
         "make_kernel", "split_dimension", "join_dimensions",
         "tag_dimensions",
-        "extract_subst", "apply_subst",
+        "extract_subst", "expand_subst",
         "precompute", "add_prefetch"
         ]
 
@@ -422,14 +423,14 @@ def tag_dimensions(kernel, iname_to_tag, force=False):
 # {{{ convenience: add_prefetch
 
 def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
-        default_tag="l.auto", rule_name=None, footprint_indices=None):
+        default_tag="l.auto", rule_name=None, footprint_subscripts=None):
     """Prefetch all accesses to the variable *var_name*, with all accesses
     being swept through *sweep_inames*.
 
     :ivar dim_arg_names: List of names representing each fetch axis.
     :ivar rule_name: base name of the generated temporary variable.
-    :ivar footprint_indices: A list of tuples indicating the index set used
-        to generate the footprint.
+    :ivar footprint_subscripts: A list of tuples indicating the index (i.e.
+        subscript) tuples used to generate the footprint.
 
         If only one such set of indices is desired, this may also be specified
         directly by putting an index expression into *var_name*. Substitutions
@@ -437,7 +438,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         applied to these indices.
     """
 
-    # {{{ fish indexing out of var_name and into sweep_indices
+    # {{{ fish indexing out of var_name and into footprint_subscripts
 
     from loopy.symbolic import parse
     parsed_var_name = parse(var_name)
@@ -447,13 +448,13 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         # nothing to see
         pass
     elif isinstance(parsed_var_name, Subscript):
-        if footprint_indices is not None:
-            raise TypeError("if footprint_indices is specified, then var_name "
+        if footprint_subscripts is not None:
+            raise TypeError("if footprint_subscripts is specified, then var_name "
                     "may not contain a subscript")
 
         assert isinstance(parsed_var_name.aggregate, Variable)
         var_name = parsed_var_name.aggregate.name
-        sweep_indices = [parsed_var_name.index]
+        footprint_subscripts = [parsed_var_name.index]
     else:
         raise ValueError("var_name must either be a variable name or a subscript")
 
@@ -486,20 +487,13 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
     kernel = extract_subst(kernel, rule_name, uni_template, parameters)
 
-    new_fetch_dims = []
-    for fd in sweep_inames:
-        if isinstance(fd, int):
-            new_fetch_dims.append(parameters[fd])
-        else:
-            new_fetch_dims.append(fd)
-
     footprint_generators = None
 
-    if sweep_indices is not None:
-        if not isinstance(sweep_indices, (list, tuple)):
-            sweep_indices = [sweep_indices]
+    if footprint_subscripts is not None:
+        if not isinstance(footprint_subscripts, (list, tuple)):
+            footprint_subscripts = [footprint_subscripts]
 
-        def standardize_sweep_indices(si):
+        def standardize_footprint_indices(si):
             if isinstance(si, str):
                 from loopy.symbolic import parse
                 si = parse(si)
@@ -517,11 +511,11 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
             return si
 
-        sweep_indices = [standardize_sweep_indices(si) for si in sweep_indices]
+        footprint_subscripts = [standardize_footprint_indices(si) for si in footprint_subscripts]
 
         from pymbolic.primitives import Variable
         footprint_generators = [
-                Variable(var_name)(*si) for si in sweep_indices]
+                Variable(var_name)(*si) for si in footprint_subscripts]
 
     new_kernel = precompute(kernel, rule_name, arg.dtype, sweep_inames,
             footprint_generators=footprint_generators,
@@ -531,7 +525,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
     # If the rule survived past precompute() (i.e. some accesses fell outside
     # the footprint), get rid of it before moving on.
     if rule_name in new_kernel.substitutions:
-        return apply_subst(new_kernel, rule_name)
+        return expand_subst(new_kernel, rule_name)
     else:
         return new_kernel
 
diff --git a/loopy/cse.py b/loopy/cse.py
index db1da2d19ae46f19b1b2299573d3bae490a2e44a..6c8455b9a2a9c9e414a2b40e834c111154b923ac 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -13,7 +13,18 @@ from pymbolic import var
 
 
 class InvocationDescriptor(Record):
-    __slots__ = ["expr", "args", "expands_footprint", "is_in_footprint"]
+    __slots__ = [
+            "expr",
+            "args",
+            "expands_footprint",
+            "is_in_footprint",
+
+            # Record from which substitution rule this invocation of the rule
+            # being precomputed originated. If all invocations end up being
+            # in-footprint, then the replacement with the prefetch can be made
+            # within the rule.
+            "from_subst_rule"
+            ]
 
 
 
@@ -91,7 +102,8 @@ def build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
     # stor2sweep is back in map_space
     return stor2sweep
 
-def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
+def build_global_storage_to_sweep_map(invocation_descriptors,
+        dup_sweep_index, domain_dup_sweep,
         storage_axis_names, storage_axis_sources, prime_sweep_inames):
     """
     As a side effect, this fills out is_in_footprint in the
@@ -120,6 +132,15 @@ def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
         global_stor2sweep = isl.Map.from_basic_map(stor2sweep)
     global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep)
 
+    # function to move non-sweep inames into parameter space
+    def move_non_sweep_to_par(s2smap):
+        sp = s2smap.get_space()
+        return s2smap.move_dims(
+                dim_type.param, sp.dim(dim_type.param),
+                dim_type.out, 0, dup_sweep_index)
+
+    global_s2s_par_dom = move_non_sweep_to_par(global_stor2sweep).domain()
+
     # check if non-footprint-building invocation descriptors fall into footprint
     for invdesc in invocation_descriptors:
         stor2sweep = build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
@@ -128,13 +149,16 @@ def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
         if isinstance(stor2sweep, isl.BasicMap):
             stor2sweep = isl.Map.from_basic_map(stor2sweep)
 
-        stor2sweep = stor2sweep.intersect_range(domain_dup_sweep)
+        stor2sweep = move_non_sweep_to_par(
+                stor2sweep.intersect_range(domain_dup_sweep))
+
+        is_in_footprint = stor2sweep.domain().is_subset(
+                global_s2s_par_dom)
 
         if not invdesc.expands_footprint:
-            invdesc.is_in_footprint = stor2sweep.is_subset(global_stor2sweep)
+            invdesc.is_in_footprint = is_in_footprint
         else:
-            assert stor2sweep.domain().is_subset(global_stor2sweep.domain())
-
+            assert is_in_footprint
 
     return global_stor2sweep
 
@@ -193,7 +217,7 @@ def get_access_info(kernel, subst_name,
     # }}}
 
     stor2sweep = build_global_storage_to_sweep_map(
-            invocation_descriptors, domain_dup_sweep,
+            invocation_descriptors, dup_sweep_index, domain_dup_sweep,
             storage_axis_names, storage_axis_sources, prime_sweep_inames)
 
     storage_base_indices, storage_shape = compute_bounds(
@@ -281,13 +305,13 @@ def simplify_via_aff(expr):
 
 
 
-def precompute(kernel, subst_name, dtype, sweep_axes=[],
+def precompute(kernel, subst_name, dtype, sweep_inames=[],
         footprint_generators=None,
         storage_axes=None, new_storage_axis_names=None, storage_axis_to_tag={},
         default_tag="l.auto"):
     """Precompute the expression described in the substitution rule *subst_name*
     and store it in a temporary array. A precomputation needs two things to operate,
-    a list of *sweep_axes* (order irrelevant) and an ordered list of *storage_axes*
+    a list of *sweep_inames* (order irrelevant) and an ordered list of *storage_axes*
     (whose order will describe the axis ordering of the temporary array).
 
     *subst_name* may contain a period (".") to filter out a subset of the
@@ -308,11 +332,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
       within its arguments. A new, dedicated storage axis is allocated for
       such an axis.
 
-    * The axis is a formal argument name of the substitution rule.
-      This is equivalent to specifying *all* inames occurring within
-      the so-named formal argument at *all* usage sites.
-
-    :arg sweep_axes: A :class:`list` of inames and/or rule argument names to be swept.
+    :arg sweep_inames: A :class:`list` of inames and/or rule argument names to be swept.
     :arg storage_axes: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes.
 
     If `storage_axes` is not specified, it defaults to the arrangement
@@ -323,6 +343,19 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     eliminated.
     """
 
+    # {{{ check arguments
+
+    for iname in sweep_inames:
+        if iname not in kernel.all_inames():
+            raise RuntimeError("sweep iname '%s' is not a known iname"
+                    % iname)
+
+    if footprint_generators is not None:
+        if isinstance(footprint_generators, str):
+            footprint_generators = [footprint_generators]
+
+    # }}}
+
     from loopy.symbolic import SubstitutionCallbackMapper
 
     c_subst_name = subst_name.replace(".", "_")
@@ -334,27 +367,11 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     subst = kernel.substitutions[subst_name]
     arg_names = subst.arguments
 
-    # {{{ gather up invocations
+    # {{{ create list of invocation descriptors
 
     invocation_descriptors = []
 
-    def gather_substs(expr, name, instance, args, rec):
-        if len(args) != len(subst.arguments):
-            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
-                    % (subst_name, len(args), len(subst.arguments), ))
-
-        arg_deps = get_dependencies(args)
-        if not arg_deps <= kernel.all_inames():
-            raise RuntimeError("CSE arguments in '%s' do not consist "
-                    "exclusively of inames" % expr)
-
-        invocation_descriptors.append(
-                InvocationDescriptor(expr=expr, args=args,
-                    expands_footprint=footprint_generators is None))
-        return expr
-
-    from loopy.symbolic import SubstitutionCallbackMapper
-    scm = SubstitutionCallbackMapper([(subst_name, subst_instance)], gather_substs)
+    # {{{ process invocations in footprint generators
 
     if footprint_generators:
         for fpg in footprint_generators:
@@ -373,56 +390,79 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
             invocation_descriptors.append(
                     InvocationDescriptor(expr=fpg, args=args,
-                        expands_footprint=True))
+                        expands_footprint=True,
+                        from_subst_rule=None))
+
+    # }}}
+
+    # {{{ gather up invocations in kernel code
+
+    current_subst_rule_stack = []
 
     # We need to work on the fully expanded form of an expression.
     # To that end, instantiate a substitutor.
     from loopy.symbolic import ParametrizedSubstitutor
     rules_except_mine = kernel.substitutions.copy()
     del rules_except_mine[subst_name]
-    subst_expander = ParametrizedSubstitutor(rules_except_mine)
+    subst_expander = ParametrizedSubstitutor(rules_except_mine,
+            one_level=True)
 
-    for insn in kernel.instructions:
-        # We can't deal with invocations that involve other substitution's
-        # arguments. Therefore, fully expand each instruction and look at
-        # the invocations in subst_name occurring there.
+    def gather_substs(expr, name, instance, args, rec):
+        if subst_name != name:
+            if name in subst_expander.rules:
+                # We can't deal with invocations that involve other substitution's
+                # arguments. Therefore, fully expand each encountered substitution
+                # rule and look at the invocations of subst_name occurring in its
+                # body.
+
+                expanded_expr = subst_expander(expr)
+                current_subst_rule_stack.append(name)
+                result = rec(expanded_expr)
+                current_subst_rule_stack.pop()
+                return result
 
-        expanded_expr = subst_expander(insn.expression)
-        scm(expanded_expr)
+            else:
+                return None
 
-    if not invocation_descriptors:
-        raise RuntimeError("no invocations of '%s' found" % subst_name)
+        if subst_instance != instance:
+            # use fall-back identity mapper
+            return None
 
-    # }}}
+        if len(args) != len(subst.arguments):
+            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
+                    % (subst_name, len(args), len(subst.arguments), ))
 
-    # {{{ deal with argument names as sweep axes
+        arg_deps = get_dependencies(args)
+        if not arg_deps <= kernel.all_inames():
+            raise RuntimeError("CSE arguments in '%s' do not consist "
+                    "exclusively of inames" % expr)
 
-    # An argument name as a sweep iname means that *all*
-    # inames contained in *all* uses of the rule will be
-    # made sweep inames.
+        if current_subst_rule_stack:
+            current_subst_rule = current_subst_rule_stack[-1]
+        else:
+            current_subst_rule = None
 
-    sweep_inames = set()
+        invocation_descriptors.append(
+                InvocationDescriptor(expr=expr, args=args,
+                    expands_footprint=footprint_generators is None,
+                    from_subst_rule=current_subst_rule))
+        return expr
 
-    for invdesc in invocation_descriptors:
-        if not invdesc.expands_footprint:
-            continue
+    from loopy.symbolic import SubstitutionCallbackMapper
+    scm = SubstitutionCallbackMapper(names_filter=None, func=gather_substs)
 
-        for swaxis in sweep_axes:
-            if isinstance(swaxis, int):
-                sweep_inames.update(
-                        get_dependencies(invdesc.args[swaxis]))
-            elif swaxis in subst.arguments:
-                arg_idx = subst.arguments.index(swaxis)
-                sweep_inames.update(
-                        get_dependencies(invdesc.args[arg_idx]))
-            else:
-                sweep_inames.add(swaxis)
+    for insn in kernel.instructions:
+        scm(insn.expression)
 
-    sweep_inames = list(sweep_inames)
-    del sweep_axes
+    if not invocation_descriptors:
+        raise RuntimeError("no invocations of '%s' found" % subst_name)
 
     # }}}
 
+    # }}}
+
+    sweep_inames = list(sweep_inames)
+
     # {{{ see if we need extra storage dimensions
 
     # find inames used in argument dependencies
@@ -457,8 +497,6 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     storage_axis_name_to_tag = {}
 
     for i, saxis in enumerate(storage_axes):
-        new_name = None
-
         tag_lookup_saxis = saxis
 
         if saxis in subst.arguments:
@@ -509,7 +547,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
                             storage_axis_names, storage_axis_sources,
                             sweep_inames, invocation_descriptors)
 
-    # {{{ ensure convexity of new_domain
+    # {{{ try a few ways to get new_domain to be convex
 
     if len(new_domain.get_basic_sets()) > 1:
         hull_new_domain = new_domain.simple_hull()
@@ -587,17 +625,62 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
             left_unused_subst_rule_invocations[0] = True
             return expr
 
-        found = False
-        for invdesc in invocation_descriptors:
-            if expr == invdesc.expr:
-                found = True
-                break
+        # {{{ check if current use is in-footprint
 
-        if not invdesc.is_in_footprint:
-            left_unused_subst_rule_invocations[0] = True
-            return expr
+        if current_subst_rule is None:
+            # The current subsitution was *not* found inside another
+            # substitution rule. Try and dig up the corresponding invocation
+            # descriptor.
+
+            found = False
+            for invdesc in invocation_descriptors:
+                if expr == invdesc.expr:
+                    found = True
+                    break
 
-        assert found, expr
+            if not invdesc.is_in_footprint:
+                left_unused_subst_rule_invocations[0] = True
+                return expr
+
+            assert found, expr
+
+        else:
+            # The current subsitution *was* found inside another substitution
+            # rule. We can't dig up the corresponding invocation descriptor,
+            # because it was the result of expanding that outer substitution
+            # rule. But we do know what the current outer substitution rule is,
+            # and we can check if all uses within that rule were uniformly
+            # in-footprint. If so, we'll go ahead, otherwise we'll bomb out.
+
+            current_rule_invdescs_in_footprint = [
+                    invdesc.is_in_footprint
+                    for invdesc in invocation_descriptors
+                    if invdesc.from_subst_rule == current_subst_rule]
+
+            from pytools import all
+            all_in = all(current_rule_invdescs_in_footprint)
+            all_out = all(not b for b in current_rule_invdescs_in_footprint)
+
+            assert not (all_in and all_out)
+
+            if not (all_in or all_out):
+                raise RuntimeError("substitution '%s' (being precomputed) is used "
+                        "from within substitution '%s', but not all uses of "
+                        "'%s' within '%s' "
+                        "are uniformly within-footprint or outside of the footprint, "
+                        "making a unique replacement of '%s' impossible. Please expand "
+                        "'%s' and try again."
+                        % (subst_name, current_subst_rule,
+                            subst_name, current_subst_rule,
+                            subst_name, current_subst_rule))
+
+            if all_out:
+                left_unused_subst_rule_invocations[0] = True
+                return expr
+
+            assert all_in
+
+        # }}}
 
         if len(args) != len(subst.arguments):
             raise ValueError("invocation of '%s' with too few arguments"
@@ -628,15 +711,18 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     new_insns = [compute_insn]
 
+    current_subst_rule = None
     sub_map = SubstitutionCallbackMapper([subst_name], do_substs)
     for insn in kernel.instructions:
         new_insn = insn.copy(expression=sub_map(insn.expression))
         new_insns.append(new_insn)
 
     # also catch uses of our rule in other substitution rules
-    new_substs = dict(
-            (s.name, s.copy(expression=sub_map(s.expression)))
-            for s in kernel.substitutions.itervalues())
+    new_substs = {}
+    for s in kernel.substitutions.itervalues():
+        current_subst_rule = s.name
+        new_substs[s.name] = s.copy(
+                expression=sub_map(s.expression))
 
     # If the subst above caught all uses of the subst rule, get rid of it.
     if not left_unused_subst_rule_invocations[0]:
@@ -658,5 +744,4 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
 
 
-
 # vim: foldmethod=marker
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 5c86369a395278937765629f76352ecf219afe83..1b1388653880ec2b0a2af6617571ad03ada9fbec 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -736,8 +736,8 @@ def adjust_local_temp_var_storage(kernel):
 
 
 def preprocess_kernel(kernel):
-    from loopy.subst import apply_subst
-    kernel = apply_subst(kernel)
+    from loopy.subst import expand_subst
+    kernel = expand_subst(kernel)
 
     kernel = realize_reduction(kernel)
 
diff --git a/loopy/subst.py b/loopy/subst.py
index 636252304263587e59dd4718dea63d199acb2274..1ed9788cabe1ed37d5b5370690d9891f0293dfb7 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -172,7 +172,7 @@ def extract_subst(kernel, subst_name, template, parameters):
 
 
 
-def apply_subst(kernel, subst_name=None):
+def expand_subst(kernel, subst_name=None):
     if subst_name is None:
         rules = kernel.substitutions
     else:
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 2b1d584726d27d1d5c85ca44b2f53fb62d98db98..1ed6eb4fc3d21df25363acb62236fc05cf2200fd 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -545,11 +545,17 @@ class SubstitutionCallbackMapper(IdentityMapper):
 # {{{ parametrized substitutor
 
 class ParametrizedSubstitutor(object):
-    def __init__(self, rules):
+    def __init__(self, rules, one_level=False):
         self.rules = rules
+        self.one_level = one_level
 
     def __call__(self, expr):
+        level = [0]
+
         def expand_if_known(expr, name, instance, args, rec):
+            if self.one_level and level[0] > 0:
+                return None
+
             rule = self.rules[name]
             if len(rule.arguments) != len(args):
                 raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
@@ -559,7 +565,11 @@ class ParametrizedSubstitutor(object):
             subst_map = SubstitutionMapper(make_subst_func(
                 dict(zip(rule.arguments, args))))
 
-            return rec(subst_map(rule.expression))
+            level[0] += 1
+            result = rec(subst_map(rule.expression))
+            level[0] -= 1
+
+            return result
 
         scm = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known)
         return scm(expr)
diff --git a/test/test_nbody.py b/test/test_nbody.py
index fb2cee8c26e8585c3ab05e9d855a4807805a2d53..4f888ef09d0c4a66b23f90e06664fc0f36d89588 100644
--- a/test/test_nbody.py
+++ b/test/test_nbody.py
@@ -44,6 +44,7 @@ def test_nbody(ctx_factory):
         return knl, []
 
     def variant_gpu(knl):
+        knl = lp.expand_subst(knl)
         knl = lp.split_dimension(knl, "i", 256,
                 outer_tag="g.0", inner_tag="l.0", slabs=(0,1))
         knl = lp.split_dimension(knl, "j", 256, slabs=(0,1))
diff --git a/test/test_sem_reagan.py b/test/test_sem_reagan.py
index 8d31024bb85d0674067ecd5a8873a9363ed2afea..bdccb112f2c0524bdadd5126c4cbdaae7c023b4c 100644
--- a/test/test_sem_reagan.py
+++ b/test/test_sem_reagan.py
@@ -47,24 +47,28 @@ def test_tim2d(ctx_factory):
             # lp.ImageArg("D", dtype, shape=(n, n)),
             lp.ScalarArg("K", np.int32, approximately=1000),
             ],
-             name="semlap2D", assumptions="K>=1")
+            name="semlap2D", assumptions="K>=1")
 
     seq_knl = knl
 
     def variant_orig(knl):
+        knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1", e="g.0"))
+
         knl = lp.add_prefetch(knl, "D", ["m", "j", "i","o"])
         knl = lp.add_prefetch(knl, "u", ["i", "j",  "o"])
-        knl = lp.precompute(knl, "ur", np.float32, ["a", "b"])
-        knl = lp.precompute(knl, "us", np.float32, ["a", "b"])
-        knl = lp.split_dimension(knl, "e", 1, outer_tag="g.0")#, slabs=(0, 1))
 
-        knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1"))
+        knl = lp.precompute(knl, "ur", np.float32, ["m", "j"], "ur(m,j)")
+        knl = lp.precompute(knl, "us", np.float32, ["i", "m"], "us(i,m)")
+
+        knl = lp.add_prefetch(knl, "G")
+
+        knl = lp.precompute(knl, "Gux", np.float32, ["m", "j"], "Gux(m,j)")
+        knl = lp.precompute(knl, "Guy", np.float32, ["i", "m"], "Gux(i,m)")
+
         knl = lp.tag_dimensions(knl, dict(o="unr"))
         knl = lp.tag_dimensions(knl, dict(m="unr"))
 
-
-        # knl = lp.add_prefetch(knl, "G", [2,3], default_tag=None) # axis/argument indices on G
-        knl = lp.add_prefetch(knl, "G", [2,3]) # axis/argument indices on G
+        return knl
 
     def variant_prefetch(knl):
         knl = lp.precompute(knl, "ur", np.float32, ["a", "b"])
@@ -94,8 +98,8 @@ def test_tim2d(ctx_factory):
         knl = lp.precompute(knl, "Guy", np.float32, ["a", "b"])
         return knl
 
-    #for variant in [variant_orig]:
-    for variant in [variant_1]:
+    for variant in [variant_orig]:
+    #for variant in [variant_1]:
         kernel_gen = lp.generate_loop_schedules(variant(knl))
         kernel_gen = lp.check_kernels(kernel_gen, dict(K=1000))
 
@@ -103,7 +107,7 @@ def test_tim2d(ctx_factory):
         lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
                 op_count=K*(n*n*n*2*2 + n*n*2*3 + n**3 * 2*2)/1e9,
                 op_label="GFlops",
-                parameters={"K": K})
+                parameters={"K": K}, print_ref_code=True)