From 6d8feb56d9588f695830a3e49ab19586ef57a281 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 3 Mar 2012 15:59:15 -0600
Subject: [PATCH] Multiple things, mainly related to prefetch.

- Rename apply_subst to expand_subst.

- Track if precomputes in other substitution rules are uniformly
  in-footprint, and if so, perform substitution to storage reference
  even within substitution rules.

- Remove 'formal argument' sweep spec from precompute(). This is purely
  counterintuitive, because it finds an intermediate set of inames that
  then get swept out and thereby may pick up much more than the formal
  argument referenced.

- Fix the 'in-footprint' criterion to be the storage-to-sweep map with
  the sweep dimensions projected out.
---
 MEMO                    |  11 +-
 doc/reference.rst       |   2 +-
 loopy/__init__.py       |  42 +++----
 loopy/cse.py            | 243 +++++++++++++++++++++++++++-------------
 loopy/preprocess.py     |   4 +-
 loopy/subst.py          |   2 +-
 loopy/symbolic.py       |  14 ++-
 test/test_nbody.py      |   1 +
 test/test_sem_reagan.py |  26 +++--
 9 files changed, 219 insertions(+), 126 deletions(-)

diff --git a/MEMO b/MEMO
index 95d8f4527..50a45695d 100644
--- a/MEMO
+++ b/MEMO
@@ -52,16 +52,11 @@ To-do
 
 - reg rolling
 
-- nbody GPU
-  -> pending better prefetch spec
-  - Prefetch by sample access
-  - Exclude by precompute name
-
 - Expose iname-duplicate-and-rename as a primitive.
 
 - add_prefetch gets a flag to separate out each access
 
-- Allow parameters to be varying during run-time varying, substituting values
+- Allow parameters to be varying during run-time, substituting values
   that depend on other inames?
 
 - Fix all tests
@@ -112,6 +107,10 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- nbody GPU
+  -> pending better prefetch spec
+  - Prefetch by sample access
+
 - How is intra-instruction ordering of ILP loops going to be determined?
   (taking into account that it could vary even per-instruction?)
 
diff --git a/doc/reference.rst b/doc/reference.rst
index 0c5128490..3202fe000 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -131,7 +131,7 @@ Dealing with Substitution Rules
 
 .. autofunction:: extract_subst
 
-.. autofunction:: apply_subst
+.. autofunction:: expand_subst
 
 Precomputation and Prefetching
 ------------------------------
diff --git a/loopy/__init__.py b/loopy/__init__.py
index e1745099d..0043a21b1 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -23,7 +23,7 @@ class LoopyAdvisory(UserWarning):
 from loopy.kernel import ScalarArg, ArrayArg, ConstantArrayArg, ImageArg
 
 from loopy.kernel import AutoFitLocalIndexTag, get_dot_dependency_graph, LoopKernel
-from loopy.subst import extract_subst, apply_subst
+from loopy.subst import extract_subst, expand_subst
 from loopy.cse import precompute
 from loopy.preprocess import preprocess_kernel, realize_reduction
 from loopy.schedule import generate_loop_schedules
@@ -33,12 +33,13 @@ from loopy.check import check_kernels
 
 __all__ = ["ScalarArg", "ArrayArg", "ConstantArrayArg", "ImageArg", "LoopKernel",
         "get_dot_dependency_graph",
-        "preprocess_kernel", "generate_loop_schedules",
+        "preprocess_kernel", "realize_reduction",
+        "generate_loop_schedules",
         "generate_code",
         "CompiledKernel", "drive_timing_run", "auto_test_vs_ref", "check_kernels",
         "make_kernel", "split_dimension", "join_dimensions",
         "tag_dimensions",
-        "extract_subst", "apply_subst",
+        "extract_subst", "expand_subst",
         "precompute", "add_prefetch"
         ]
 
@@ -422,14 +423,14 @@ def tag_dimensions(kernel, iname_to_tag, force=False):
 # {{{ convenience: add_prefetch
 
 def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
-        default_tag="l.auto", rule_name=None, footprint_indices=None):
+        default_tag="l.auto", rule_name=None, footprint_subscripts=None):
     """Prefetch all accesses to the variable *var_name*, with all accesses
     being swept through *sweep_inames*.
 
     :ivar dim_arg_names: List of names representing each fetch axis.
     :ivar rule_name: base name of the generated temporary variable.
-    :ivar footprint_indices: A list of tuples indicating the index set used
-        to generate the footprint.
+    :ivar footprint_subscripts: A list of tuples indicating the index (i.e.
+        subscript) tuples used to generate the footprint.
 
         If only one such set of indices is desired, this may also be specified
         directly by putting an index expression into *var_name*. Substitutions
@@ -437,7 +438,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         applied to these indices.
     """
 
-    # {{{ fish indexing out of var_name and into sweep_indices
+    # {{{ fish indexing out of var_name and into footprint_subscripts
 
     from loopy.symbolic import parse
     parsed_var_name = parse(var_name)
@@ -447,13 +448,13 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
         # nothing to see
         pass
     elif isinstance(parsed_var_name, Subscript):
-        if footprint_indices is not None:
-            raise TypeError("if footprint_indices is specified, then var_name "
+        if footprint_subscripts is not None:
+            raise TypeError("if footprint_subscripts is specified, then var_name "
                     "may not contain a subscript")
 
         assert isinstance(parsed_var_name.aggregate, Variable)
         var_name = parsed_var_name.aggregate.name
-        sweep_indices = [parsed_var_name.index]
+        footprint_subscripts = [parsed_var_name.index]
     else:
         raise ValueError("var_name must either be a variable name or a subscript")
 
@@ -486,20 +487,13 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
     kernel = extract_subst(kernel, rule_name, uni_template, parameters)
 
-    new_fetch_dims = []
-    for fd in sweep_inames:
-        if isinstance(fd, int):
-            new_fetch_dims.append(parameters[fd])
-        else:
-            new_fetch_dims.append(fd)
-
     footprint_generators = None
 
-    if sweep_indices is not None:
-        if not isinstance(sweep_indices, (list, tuple)):
-            sweep_indices = [sweep_indices]
+    if footprint_subscripts is not None:
+        if not isinstance(footprint_subscripts, (list, tuple)):
+            footprint_subscripts = [footprint_subscripts]
 
-        def standardize_sweep_indices(si):
+        def standardize_footprint_indices(si):
             if isinstance(si, str):
                 from loopy.symbolic import parse
                 si = parse(si)
@@ -517,11 +511,11 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
             return si
 
-        sweep_indices = [standardize_sweep_indices(si) for si in sweep_indices]
+        footprint_subscripts = [standardize_footprint_indices(si) for si in footprint_subscripts]
 
         from pymbolic.primitives import Variable
         footprint_generators = [
-                Variable(var_name)(*si) for si in sweep_indices]
+                Variable(var_name)(*si) for si in footprint_subscripts]
 
     new_kernel = precompute(kernel, rule_name, arg.dtype, sweep_inames,
             footprint_generators=footprint_generators,
@@ -531,7 +525,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
     # If the rule survived past precompute() (i.e. some accesses fell outside
     # the footprint), get rid of it before moving on.
     if rule_name in new_kernel.substitutions:
-        return apply_subst(new_kernel, rule_name)
+        return expand_subst(new_kernel, rule_name)
     else:
         return new_kernel
 
diff --git a/loopy/cse.py b/loopy/cse.py
index db1da2d19..6c8455b9a 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -13,7 +13,18 @@ from pymbolic import var
 
 
 class InvocationDescriptor(Record):
-    __slots__ = ["expr", "args", "expands_footprint", "is_in_footprint"]
+    __slots__ = [
+            "expr",
+            "args",
+            "expands_footprint",
+            "is_in_footprint",
+
+            # Record from which substitution rule this invocation of the rule
+            # being precomputed originated. If all invocations end up being
+            # in-footprint, then the replacement with the prefetch can be made
+            # within the rule.
+            "from_subst_rule"
+            ]
 
 
 
@@ -91,7 +102,8 @@ def build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
     # stor2sweep is back in map_space
     return stor2sweep
 
-def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
+def build_global_storage_to_sweep_map(invocation_descriptors,
+        dup_sweep_index, domain_dup_sweep,
         storage_axis_names, storage_axis_sources, prime_sweep_inames):
     """
     As a side effect, this fills out is_in_footprint in the
@@ -120,6 +132,15 @@ def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
         global_stor2sweep = isl.Map.from_basic_map(stor2sweep)
     global_stor2sweep = global_stor2sweep.intersect_range(domain_dup_sweep)
 
+    # function to move non-sweep inames into parameter space
+    def move_non_sweep_to_par(s2smap):
+        sp = s2smap.get_space()
+        return s2smap.move_dims(
+                dim_type.param, sp.dim(dim_type.param),
+                dim_type.out, 0, dup_sweep_index)
+
+    global_s2s_par_dom = move_non_sweep_to_par(global_stor2sweep).domain()
+
     # check if non-footprint-building invocation descriptors fall into footprint
     for invdesc in invocation_descriptors:
         stor2sweep = build_per_access_storage_to_sweep_map(invdesc, domain_dup_sweep,
@@ -128,13 +149,16 @@ def build_global_storage_to_sweep_map(invocation_descriptors, domain_dup_sweep,
         if isinstance(stor2sweep, isl.BasicMap):
             stor2sweep = isl.Map.from_basic_map(stor2sweep)
 
-        stor2sweep = stor2sweep.intersect_range(domain_dup_sweep)
+        stor2sweep = move_non_sweep_to_par(
+                stor2sweep.intersect_range(domain_dup_sweep))
+
+        is_in_footprint = stor2sweep.domain().is_subset(
+                global_s2s_par_dom)
 
         if not invdesc.expands_footprint:
-            invdesc.is_in_footprint = stor2sweep.is_subset(global_stor2sweep)
+            invdesc.is_in_footprint = is_in_footprint
         else:
-            assert stor2sweep.domain().is_subset(global_stor2sweep.domain())
-
+            assert is_in_footprint
 
     return global_stor2sweep
 
@@ -193,7 +217,7 @@ def get_access_info(kernel, subst_name,
     # }}}
 
     stor2sweep = build_global_storage_to_sweep_map(
-            invocation_descriptors, domain_dup_sweep,
+            invocation_descriptors, dup_sweep_index, domain_dup_sweep,
             storage_axis_names, storage_axis_sources, prime_sweep_inames)
 
     storage_base_indices, storage_shape = compute_bounds(
@@ -281,13 +305,13 @@ def simplify_via_aff(expr):
 
 
 
-def precompute(kernel, subst_name, dtype, sweep_axes=[],
+def precompute(kernel, subst_name, dtype, sweep_inames=[],
         footprint_generators=None,
         storage_axes=None, new_storage_axis_names=None, storage_axis_to_tag={},
         default_tag="l.auto"):
     """Precompute the expression described in the substitution rule *subst_name*
     and store it in a temporary array. A precomputation needs two things to operate,
-    a list of *sweep_axes* (order irrelevant) and an ordered list of *storage_axes*
+    a list of *sweep_inames* (order irrelevant) and an ordered list of *storage_axes*
     (whose order will describe the axis ordering of the temporary array).
 
     *subst_name* may contain a period (".") to filter out a subset of the
@@ -308,11 +332,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
       within its arguments. A new, dedicated storage axis is allocated for
       such an axis.
 
-    * The axis is a formal argument name of the substitution rule.
-      This is equivalent to specifying *all* inames occurring within
-      the so-named formal argument at *all* usage sites.
-
-    :arg sweep_axes: A :class:`list` of inames and/or rule argument names to be swept.
+    :arg sweep_inames: A :class:`list` of inames to be swept.
     :arg storage_axes: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes.
 
     If `storage_axes` is not specified, it defaults to the arrangement
@@ -323,6 +343,19 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     eliminated.
     """
 
+    # {{{ check arguments
+
+    for iname in sweep_inames:
+        if iname not in kernel.all_inames():
+            raise RuntimeError("sweep iname '%s' is not a known iname"
+                    % iname)
+
+    if footprint_generators is not None:
+        if isinstance(footprint_generators, str):
+            footprint_generators = [footprint_generators]
+
+    # }}}
+
     from loopy.symbolic import SubstitutionCallbackMapper
 
     c_subst_name = subst_name.replace(".", "_")
@@ -334,27 +367,11 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     subst = kernel.substitutions[subst_name]
     arg_names = subst.arguments
 
-    # {{{ gather up invocations
+    # {{{ create list of invocation descriptors
 
     invocation_descriptors = []
 
-    def gather_substs(expr, name, instance, args, rec):
-        if len(args) != len(subst.arguments):
-            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
-                    % (subst_name, len(args), len(subst.arguments), ))
-
-        arg_deps = get_dependencies(args)
-        if not arg_deps <= kernel.all_inames():
-            raise RuntimeError("CSE arguments in '%s' do not consist "
-                    "exclusively of inames" % expr)
-
-        invocation_descriptors.append(
-                InvocationDescriptor(expr=expr, args=args,
-                    expands_footprint=footprint_generators is None))
-        return expr
-
-    from loopy.symbolic import SubstitutionCallbackMapper
-    scm = SubstitutionCallbackMapper([(subst_name, subst_instance)], gather_substs)
+    # {{{ process invocations in footprint generators
 
     if footprint_generators:
         for fpg in footprint_generators:
@@ -373,56 +390,79 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
             invocation_descriptors.append(
                     InvocationDescriptor(expr=fpg, args=args,
-                        expands_footprint=True))
+                        expands_footprint=True,
+                        from_subst_rule=None))
+
+    # }}}
+
+    # {{{ gather up invocations in kernel code
+
+    current_subst_rule_stack = []
 
     # We need to work on the fully expanded form of an expression.
     # To that end, instantiate a substitutor.
     from loopy.symbolic import ParametrizedSubstitutor
     rules_except_mine = kernel.substitutions.copy()
     del rules_except_mine[subst_name]
-    subst_expander = ParametrizedSubstitutor(rules_except_mine)
+    subst_expander = ParametrizedSubstitutor(rules_except_mine,
+            one_level=True)
 
-    for insn in kernel.instructions:
-        # We can't deal with invocations that involve other substitution's
-        # arguments. Therefore, fully expand each instruction and look at
-        # the invocations in subst_name occurring there.
+    def gather_substs(expr, name, instance, args, rec):
+        if subst_name != name:
+            if name in subst_expander.rules:
+                # We can't deal with invocations that involve other substitutions'
+                # arguments. Therefore, fully expand each encountered substitution
+                # rule and look at the invocations of subst_name occurring in its
+                # body.
+
+                expanded_expr = subst_expander(expr)
+                current_subst_rule_stack.append(name)
+                result = rec(expanded_expr)
+                current_subst_rule_stack.pop()
+                return result
 
-        expanded_expr = subst_expander(insn.expression)
-        scm(expanded_expr)
+            else:
+                return None
 
-    if not invocation_descriptors:
-        raise RuntimeError("no invocations of '%s' found" % subst_name)
+        if subst_instance != instance:
+            # use fall-back identity mapper
+            return None
 
-    # }}}
+        if len(args) != len(subst.arguments):
+            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
+                    % (subst_name, len(args), len(subst.arguments), ))
 
-    # {{{ deal with argument names as sweep axes
+        arg_deps = get_dependencies(args)
+        if not arg_deps <= kernel.all_inames():
+            raise RuntimeError("CSE arguments in '%s' do not consist "
+                    "exclusively of inames" % expr)
 
-    # An argument name as a sweep iname means that *all*
-    # inames contained in *all* uses of the rule will be
-    # made sweep inames.
+        if current_subst_rule_stack:
+            current_subst_rule = current_subst_rule_stack[-1]
+        else:
+            current_subst_rule = None
 
-    sweep_inames = set()
+        invocation_descriptors.append(
+                InvocationDescriptor(expr=expr, args=args,
+                    expands_footprint=footprint_generators is None,
+                    from_subst_rule=current_subst_rule))
+        return expr
 
-    for invdesc in invocation_descriptors:
-        if not invdesc.expands_footprint:
-            continue
+    from loopy.symbolic import SubstitutionCallbackMapper
+    scm = SubstitutionCallbackMapper(names_filter=None, func=gather_substs)
 
-        for swaxis in sweep_axes:
-            if isinstance(swaxis, int):
-                sweep_inames.update(
-                        get_dependencies(invdesc.args[swaxis]))
-            elif swaxis in subst.arguments:
-                arg_idx = subst.arguments.index(swaxis)
-                sweep_inames.update(
-                        get_dependencies(invdesc.args[arg_idx]))
-            else:
-                sweep_inames.add(swaxis)
+    for insn in kernel.instructions:
+        scm(insn.expression)
 
-    sweep_inames = list(sweep_inames)
-    del sweep_axes
+    if not invocation_descriptors:
+        raise RuntimeError("no invocations of '%s' found" % subst_name)
 
     # }}}
 
+    # }}}
+
+    sweep_inames = list(sweep_inames)
+
     # {{{ see if we need extra storage dimensions
 
     # find inames used in argument dependencies
@@ -457,8 +497,6 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
     storage_axis_name_to_tag = {}
 
     for i, saxis in enumerate(storage_axes):
-        new_name = None
-
         tag_lookup_saxis = saxis
 
         if saxis in subst.arguments:
@@ -509,7 +547,7 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
                             storage_axis_names, storage_axis_sources,
                             sweep_inames, invocation_descriptors)
 
-    # {{{ ensure convexity of new_domain
+    # {{{ try a few ways to get new_domain to be convex
 
     if len(new_domain.get_basic_sets()) > 1:
         hull_new_domain = new_domain.simple_hull()
@@ -587,17 +625,62 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
             left_unused_subst_rule_invocations[0] = True
             return expr
 
-        found = False
-        for invdesc in invocation_descriptors:
-            if expr == invdesc.expr:
-                found = True
-                break
+        # {{{ check if current use is in-footprint
 
-        if not invdesc.is_in_footprint:
-            left_unused_subst_rule_invocations[0] = True
-            return expr
+        if current_subst_rule is None:
+            # The current substitution was *not* found inside another
+            # substitution rule. Try and dig up the corresponding invocation
+            # descriptor.
+
+            found = False
+            for invdesc in invocation_descriptors:
+                if expr == invdesc.expr:
+                    found = True
+                    break
 
-        assert found, expr
+            if not invdesc.is_in_footprint:
+                left_unused_subst_rule_invocations[0] = True
+                return expr
+
+            assert found, expr
+
+        else:
+            # The current substitution *was* found inside another substitution
+            # rule. We can't dig up the corresponding invocation descriptor,
+            # because it was the result of expanding that outer substitution
+            # rule. But we do know what the current outer substitution rule is,
+            # and we can check if all uses within that rule were uniformly
+            # in-footprint. If so, we'll go ahead, otherwise we'll bomb out.
+
+            current_rule_invdescs_in_footprint = [
+                    invdesc.is_in_footprint
+                    for invdesc in invocation_descriptors
+                    if invdesc.from_subst_rule == current_subst_rule]
+
+            from pytools import all
+            all_in = all(current_rule_invdescs_in_footprint)
+            all_out = all(not b for b in current_rule_invdescs_in_footprint)
+
+            assert not (all_in and all_out)
+
+            if not (all_in or all_out):
+                raise RuntimeError("substitution '%s' (being precomputed) is used "
+                        "from within substitution '%s', but not all uses of "
+                        "'%s' within '%s' "
+                        "are uniformly within-footprint or outside of the footprint, "
+                        "making a unique replacement of '%s' impossible. Please expand "
+                        "'%s' and try again."
+                        % (subst_name, current_subst_rule,
+                            subst_name, current_subst_rule,
+                            subst_name, current_subst_rule))
+
+            if all_out:
+                left_unused_subst_rule_invocations[0] = True
+                return expr
+
+            assert all_in
+
+        # }}}
 
         if len(args) != len(subst.arguments):
             raise ValueError("invocation of '%s' with too few arguments"
@@ -628,15 +711,18 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
     new_insns = [compute_insn]
 
+    current_subst_rule = None
     sub_map = SubstitutionCallbackMapper([subst_name], do_substs)
     for insn in kernel.instructions:
         new_insn = insn.copy(expression=sub_map(insn.expression))
         new_insns.append(new_insn)
 
     # also catch uses of our rule in other substitution rules
-    new_substs = dict(
-            (s.name, s.copy(expression=sub_map(s.expression)))
-            for s in kernel.substitutions.itervalues())
+    new_substs = {}
+    for s in kernel.substitutions.itervalues():
+        current_subst_rule = s.name
+        new_substs[s.name] = s.copy(
+                expression=sub_map(s.expression))
 
     # If the subst above caught all uses of the subst rule, get rid of it.
     if not left_unused_subst_rule_invocations[0]:
@@ -658,5 +744,4 @@ def precompute(kernel, subst_name, dtype, sweep_axes=[],
 
 
 
-
 # vim: foldmethod=marker
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 5c86369a3..1b1388653 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -736,8 +736,8 @@ def adjust_local_temp_var_storage(kernel):
 
 
 def preprocess_kernel(kernel):
-    from loopy.subst import apply_subst
-    kernel = apply_subst(kernel)
+    from loopy.subst import expand_subst
+    kernel = expand_subst(kernel)
 
     kernel = realize_reduction(kernel)
 
diff --git a/loopy/subst.py b/loopy/subst.py
index 636252304..1ed9788ca 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -172,7 +172,7 @@ def extract_subst(kernel, subst_name, template, parameters):
 
 
 
-def apply_subst(kernel, subst_name=None):
+def expand_subst(kernel, subst_name=None):
     if subst_name is None:
         rules = kernel.substitutions
     else:
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 2b1d58472..1ed6eb4fc 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -545,11 +545,17 @@ class SubstitutionCallbackMapper(IdentityMapper):
 # {{{ parametrized substitutor
 
 class ParametrizedSubstitutor(object):
-    def __init__(self, rules):
+    def __init__(self, rules, one_level=False):
         self.rules = rules
+        self.one_level = one_level
 
     def __call__(self, expr):
+        level = [0]
+
         def expand_if_known(expr, name, instance, args, rec):
+            if self.one_level and level[0] > 0:
+                return None
+
             rule = self.rules[name]
             if len(rule.arguments) != len(args):
                 raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
@@ -559,7 +565,11 @@ class ParametrizedSubstitutor(object):
             subst_map = SubstitutionMapper(make_subst_func(
                 dict(zip(rule.arguments, args))))
 
-            return rec(subst_map(rule.expression))
+            level[0] += 1
+            result = rec(subst_map(rule.expression))
+            level[0] -= 1
+
+            return result
 
         scm = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known)
         return scm(expr)
diff --git a/test/test_nbody.py b/test/test_nbody.py
index fb2cee8c2..4f888ef09 100644
--- a/test/test_nbody.py
+++ b/test/test_nbody.py
@@ -44,6 +44,7 @@ def test_nbody(ctx_factory):
         return knl, []
 
     def variant_gpu(knl):
+        knl = lp.expand_subst(knl)
         knl = lp.split_dimension(knl, "i", 256,
                 outer_tag="g.0", inner_tag="l.0", slabs=(0,1))
         knl = lp.split_dimension(knl, "j", 256, slabs=(0,1))
diff --git a/test/test_sem_reagan.py b/test/test_sem_reagan.py
index 8d31024bb..bdccb112f 100644
--- a/test/test_sem_reagan.py
+++ b/test/test_sem_reagan.py
@@ -47,24 +47,28 @@ def test_tim2d(ctx_factory):
             # lp.ImageArg("D", dtype, shape=(n, n)),
             lp.ScalarArg("K", np.int32, approximately=1000),
             ],
-             name="semlap2D", assumptions="K>=1")
+            name="semlap2D", assumptions="K>=1")
 
     seq_knl = knl
 
     def variant_orig(knl):
+        knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1", e="g.0"))
+
         knl = lp.add_prefetch(knl, "D", ["m", "j", "i","o"])
         knl = lp.add_prefetch(knl, "u", ["i", "j",  "o"])
-        knl = lp.precompute(knl, "ur", np.float32, ["a", "b"])
-        knl = lp.precompute(knl, "us", np.float32, ["a", "b"])
-        knl = lp.split_dimension(knl, "e", 1, outer_tag="g.0")#, slabs=(0, 1))
 
-        knl = lp.tag_dimensions(knl, dict(i="l.0", j="l.1"))
+        knl = lp.precompute(knl, "ur", np.float32, ["m", "j"], "ur(m,j)")
+        knl = lp.precompute(knl, "us", np.float32, ["i", "m"], "us(i,m)")
+
+        knl = lp.add_prefetch(knl, "G")
+
+        knl = lp.precompute(knl, "Gux", np.float32, ["m", "j"], "Gux(m,j)")
+        knl = lp.precompute(knl, "Guy", np.float32, ["i", "m"], "Gux(i,m)")
+
         knl = lp.tag_dimensions(knl, dict(o="unr"))
         knl = lp.tag_dimensions(knl, dict(m="unr"))
 
-
-        # knl = lp.add_prefetch(knl, "G", [2,3], default_tag=None) # axis/argument indices on G
-        knl = lp.add_prefetch(knl, "G", [2,3]) # axis/argument indices on G
+        return knl
 
     def variant_prefetch(knl):
         knl = lp.precompute(knl, "ur", np.float32, ["a", "b"])
@@ -94,8 +98,8 @@ def test_tim2d(ctx_factory):
         knl = lp.precompute(knl, "Guy", np.float32, ["a", "b"])
         return knl
 
-    #for variant in [variant_orig]:
-    for variant in [variant_1]:
+    for variant in [variant_orig]:
+    #for variant in [variant_1]:
         kernel_gen = lp.generate_loop_schedules(variant(knl))
         kernel_gen = lp.check_kernels(kernel_gen, dict(K=1000))
 
@@ -103,7 +107,7 @@ def test_tim2d(ctx_factory):
         lp.auto_test_vs_ref(seq_knl, ctx, kernel_gen,
                 op_count=K*(n*n*n*2*2 + n*n*2*3 + n**3 * 2*2)/1e9,
                 op_label="GFlops",
-                parameters={"K": K})
+                parameters={"K": K}, print_ref_code=True)
 
 
 
-- 
GitLab