From eb18c8dbc0df14be2e123e1e27caf982fa80f5d7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 5 Nov 2012 02:47:41 -0500
Subject: [PATCH] Properly deal with term rewriting in the presence of subst
 rules. Most tests pass.

---
 MEMO                      |  19 +-
 loopy/__init__.py         | 135 +++++++-------
 loopy/context_matching.py | 193 ++++++++++++++++++++
 loopy/creation.py         | 181 ++-----------------
 loopy/cse.py              | 365 +++++++++++++++++++-------------------
 loopy/kernel.py           | 129 +++++---------
 loopy/preprocess.py       |  22 ++-
 loopy/subst.py            |  30 +---
 loopy/symbolic.py         | 356 ++++++++++++++++++++-----------------
 test/test_linalg.py       |   1 -
 test/test_loopy.py        |  56 +++++-
 11 files changed, 794 insertions(+), 693 deletions(-)
 create mode 100644 loopy/context_matching.py

diff --git a/MEMO b/MEMO
index 5d69b66be..23865403a 100644
--- a/MEMO
+++ b/MEMO
@@ -41,15 +41,23 @@ Things to consider
 
 - Dependency on non-local global writes is ill-formed
 
+- No substitution rules allowed on lhs of insns
+
 To-do
 ^^^^^
 
-- Prohibit known variable names as subst rule arguments
-
 - Expose iname-duplicate-and-rename as a primitive.
 
 - Kernel fusion
 
+- ExpandingIdentityMapper
+  extract_subst -> needs WalkMapper
+  duplicate_inames
+  join_inames
+  padding
+  split_iname [DONE]
+  CSE [DONE]
+
 - Data implementation tags
   TODO initial bringup:
   - implemented_arg_info
@@ -62,11 +70,18 @@ To-do
   - vectorization
   - automatic copies
   - write_image()
+  - change_arg_to_image (test!)
+
+- Import SEM test
 
 - Make tests run on GPUs
 
 Fixes:
 
+- applied_iname_rewrites tracking for prefetch footprints isn't bulletproof
+  old inames may still be around, so the rewrite may or may not have to be
+  applied.
+
 - Group instructions by dependency/inames for scheduling, to
   increase sched. scalability
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 2f369e749..e06ffdbe3 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -39,6 +39,8 @@ from islpy import dim_type
 
 from pytools import MovedFunctionDeprecationWrapper
 
+from loopy.symbolic import ExpandingIdentityMapper
+
 
 
 
@@ -93,10 +95,46 @@ class infer_type:
 
 # {{{ split inames
 
+class _InameSplitter(ExpandingIdentityMapper):
+    def __init__(self, kernel, within,
+            split_iname, outer_iname, inner_iname, replacement_index):
+        ExpandingIdentityMapper.__init__(self,
+                kernel.substitutions, kernel.get_var_name_generator())
+
+        self.within = within
+
+        self.split_iname = split_iname
+        self.outer_iname = outer_iname
+        self.inner_iname = inner_iname
+
+        self.replacement_index = replacement_index
+
+    def map_reduction(self, expr, expn_state):
+        if self.split_iname in expr.inames and self.within(expn_state.stack):
+            new_inames = list(expr.inames)
+            new_inames.remove(self.split_iname)
+            new_inames.extend([self.outer_iname, self.inner_iname])
+
+            from loopy.symbolic import Reduction
+            return Reduction(expr.operation, tuple(new_inames),
+                        self.rec(expr.expr, expn_state))
+        else:
+            return ExpandingIdentityMapper.map_reduction(self, expr, expn_state)
+
+    def map_variable(self, expr, expn_state):
+        if expr.name == self.split_iname and self.within(expn_state.stack):
+            return self.replacement_index
+        else:
+            return ExpandingIdentityMapper.map_variable(self, expr, expn_state)
+
 def split_iname(kernel, split_iname, inner_length,
         outer_iname=None, inner_iname=None,
         outer_tag=None, inner_tag=None,
-        slabs=(0, 0), do_tagged_check=True):
+        slabs=(0, 0), do_tagged_check=True,
+        within=None):
+
+    from loopy.context_matching import parse_stack_match
+    within = parse_stack_match(within)
 
     existing_tag = kernel.iname_to_tag.get(split_iname)
     from loopy.kernel import ForceSequentialTag
@@ -154,21 +192,16 @@ def split_iname(kernel, split_iname, inner_length,
     outer = var(outer_iname)
     new_loop_index = inner + outer*inner_length
 
+    subst_map = {var(split_iname): new_loop_index}
+    applied_iname_rewrites.append(subst_map)
+
     # {{{ actually modify instructions
 
-    from loopy.symbolic import ReductionLoopSplitter
+    ins = _InameSplitter(kernel, within,
+            split_iname, outer_iname, inner_iname, new_loop_index)
 
-    rls = ReductionLoopSplitter(split_iname, outer_iname, inner_iname)
     new_insns = []
     for insn in kernel.instructions:
-        subst_map = {var(split_iname): new_loop_index}
-        applied_iname_rewrites.append(subst_map)
-
-        from loopy.symbolic import SubstitutionMapper
-        subst_mapper = SubstitutionMapper(subst_map.get)
-
-        new_expr = subst_mapper(rls(insn.expression))
-
         if split_iname in insn.forced_iname_deps:
             new_forced_iname_deps = (
                     (insn.forced_iname_deps.copy()
@@ -178,8 +211,8 @@ def split_iname(kernel, split_iname, inner_length,
             new_forced_iname_deps = insn.forced_iname_deps
 
         insn = insn.copy(
-                assignee=subst_mapper(insn.assignee),
-                expression=new_expr,
+                assignee=ins(insn.assignee, insn.id),
+                expression=ins(insn.expression, insn.id),
                 forced_iname_deps=new_forced_iname_deps)
 
         new_insns.append(insn)
@@ -188,10 +221,11 @@ def split_iname(kernel, split_iname, inner_length,
 
     iname_slab_increments = kernel.iname_slab_increments.copy()
     iname_slab_increments[outer_iname] = slabs
+
     result = (kernel
-            .map_expressions(subst_mapper, exclude_instructions=True)
             .copy(domains=new_domains,
                 iname_slab_increments=iname_slab_increments,
+                substitutions=ins.get_new_substitutions(),
                 instructions=new_insns,
                 applied_iname_rewrites=applied_iname_rewrites,
                 ))
@@ -382,9 +416,11 @@ def _add_kernel_axis(kernel, axis_name, start, stop, base_inames):
     return kernel.copy(domains=domch.get_domains_with(domain))
 
 def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
-        footprint_subscripts, arg, newly_created_vars):
+        footprint_subscripts, arg):
     """Track applied iname rewrites, deal with slice specifiers ':'."""
 
+    name_gen = kernel.get_var_name_generator()
+
     from pymbolic.primitives import Variable
 
     if footprint_subscripts is None:
@@ -423,11 +459,9 @@ def _process_footprint_subscripts(kernel, rule_name, sweep_inames,
                     raise NotImplementedError("add_prefetch only "
                             "supports full slices")
 
-                axis_name = kernel.make_unique_var_name(
-                        based_on="%s_fetch_axis_%d" % (arg.name, axis_nr),
-                        extra_used_vars=newly_created_vars)
+                axis_name = name_gen(
+                        based_on="%s_fetch_axis_%d" % (arg.name, axis_nr))
 
-                newly_created_vars.add(axis_name)
                 kernel = _add_kernel_axis(kernel, axis_name, 0, arg.shape[axis_nr],
                         frozenset(sweep_inames) | fsub_dependencies)
                 sweep_inames = sweep_inames + [axis_name]
@@ -537,11 +571,11 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
     kernel, subst_use, sweep_inames, inames_to_be_removed = \
             _process_footprint_subscripts(
                     kernel,  rule_name, sweep_inames,
-                    footprint_subscripts, arg, newly_created_vars)
+                    footprint_subscripts, arg)
 
-    new_kernel = precompute(kernel, subst_use, arg.dtype, sweep_inames,
+    new_kernel = precompute(kernel, subst_use, sweep_inames,
             new_storage_axis_names=dim_arg_names,
-            default_tag=default_tag)
+            default_tag=default_tag, dtype=arg.dtype)
 
     # {{{ remove inames that were temporarily added by slice sweeps
 
@@ -571,49 +605,19 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
 # {{{ instruction processing
 
-class _IdMatch(object):
-    def __init__(self, value):
-        self.value = value
-
-class _ExactIdMatch(_IdMatch):
-    def __call__(self, insn):
-        return insn.id == self.value
-
-class _ReIdMatch:
-    def __call__(self, insn):
-        return self.value.match(insn.id) is not None
-
-def _parse_insn_match(insn_match):
-    import re
-    colon_idx = insn_match.find(":")
-    if colon_idx == -1:
-        return _ExactIdMatch(insn_match)
-
-    match_tp = insn_match[:colon_idx]
-    match_val = insn_match[colon_idx+1:]
-
-    if match_tp == "glob":
-        from fnmatch import translate
-        return _ReIdMatch(re.compile(translate(match_val)))
-    elif match_tp == "re":
-        return _ReIdMatch(re.compile(match_val))
-    else:
-        raise ValueError("match type '%s' not understood" % match_tp)
-
-
-
-
 def find_instructions(kernel, insn_match):
-    match = _parse_insn_match(insn_match)
-    return [insn for insn in kernel.instructions if match(insn)]
+    from loopy.context_matching import parse_id_match
+    match = parse_id_match(insn_match)
+    return [insn for insn in kernel.instructions if match(insn.id, None)]
 
 def map_instructions(kernel, insn_match, f):
-    match = _parse_insn_match(insn_match)
+    from loopy.context_matching import parse_id_match
+    match = parse_id_match(insn_match)
 
     new_insns = []
 
     for insn in kernel.instructions:
-        if match(insn):
+        if match(insn.id, None):
             new_insns.append(f(insn))
         else:
             new_insns.append(insn)
@@ -623,8 +627,8 @@ def map_instructions(kernel, insn_match, f):
 def set_instruction_priority(kernel, insn_match, priority):
     """Set the priority of instructions matching *insn_match* to *priority*.
 
-    *insn_match* may be an instruction id, a regular expression prefixed by `re:`,
-    or a file-name-style glob prefixed by `glob:`.
+    *insn_match* may be any instruction id match understood by
+    :func:`loopy.context_matching.parse_id_match`.
     """
 
     def set_prio(insn): return insn.copy(priority=priority)
@@ -634,8 +638,8 @@ def add_dependency(kernel, insn_match, dependency):
     """Add the instruction dependency *dependency* to the instructions matched
     by *insn_match*.
 
-    *insn_match* may be an instruction id, a regular expression prefixed by `re:`,
-    or a file-name-style glob prefixed by `glob:`.
+    *insn_match* may be any instruction id match understood by
+    :func:`loopy.context_matching.parse_id_match`.
     """
 
     def add_dep(insn): return insn.copy(insn_deps=insn.insn_deps + [dependency])
@@ -659,6 +663,13 @@ def change_arg_to_image(knl, name):
 
 # }}}
 
+# {{{ duplicate inames
+
+def duplicate_inames(knl, inames):
+    pass
+
+# }}}
+
 
 
 
diff --git a/loopy/context_matching.py b/loopy/context_matching.py
new file mode 100644
index 000000000..51cc8a5fd
--- /dev/null
+++ b/loopy/context_matching.py
@@ -0,0 +1,193 @@
+"""Matching functionality for instruction ids and substitution
+rule invocation stacks."""
+
+from __future__ import division
+
+__copyright__ = "Copyright (C) 2012 Andreas Kloeckner"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+
+
+# {{{ id match objects
+
+class AllMatch(object):
+    def __call__(self, identifier, tag):
+        return True
+
+class RegexIdentifierMatch(object):
+    def __init__(self, id_re, tag_re=None):
+        self.id_re = id_re
+        self.tag_re = tag_re
+
+    def __call__(self, identifier, tag):
+        if self.tag_re is None:
+            return self.id_re.match(identifier) is not None
+        else:
+            if tag is None:
+                tag = ""
+
+            return (
+                    self.id_re.match(identifier) is not None
+                    and self.tag_re.match(tag) is not None)
+
+class AlternativeMatch(object):
+    def __init__(self, matches):
+        self.matches = matches
+
+    def __call__(self, identifier, tag):
+        from pytools import any
+        return any(
+                mtch(identifier, tag) for mtch in self.matches)
+
+# }}}
+
+# {{{ single id match parsing
+
+def parse_id_match(id_matches):
+    """Syntax examples:
+
+    my_insn
+    compute_*
+    fetch*$first
+    fetch*$first,store*$first
+
+    Alternatively, a list of *(name_glob, tag_glob)* tuples.
+    """
+
+    if id_matches is None:
+        return AllMatch()
+
+    if isinstance(id_matches, str):
+        id_matches = id_matches.split(",")
+
+    if len(id_matches) > 1:
+        return AlternativeMatch(parse_id_match(im) for im in id_matches)
+
+    if len(id_matches) == 0:
+        return AllMatch()
+
+    id_match, = id_matches
+    del id_matches
+
+    def re_from_glob(s):
+        import re
+        from fnmatch import translate
+        return re.compile(translate(s.strip()))
+
+    if not isinstance(id_match, tuple):
+        components = id_match.split("$")
+
+    if len(components) == 1:
+        return RegexIdentifierMatch(re_from_glob(components[0]))
+    elif len(components) == 2:
+        return RegexIdentifierMatch(
+                re_from_glob(components[0]),
+                re_from_glob(components[1]))
+    else:
+        raise RuntimeError("too many (%d) $-separated components in id match"
+                % len(components))
+
+# }}}
+
+# {{{ stack match objects
+
+# these match from the tail of the stack
+
+class StackMatchBase(object):
+    pass
+
+class AllStackMatch(StackMatchBase):
+    def __call__(self, stack):
+        return True
+
+class StackIdMatch(StackMatchBase):
+    def __init__(self, id_match, up_match):
+        self.id_match = id_match
+        self.up_match = up_match
+
+    def __call__(self, stack):
+        if not stack:
+            return False
+
+        last = stack[-1]
+        if not self.id_match(*last):
+            return False
+
+        if self.up_match is None:
+            return True
+        else:
+            return self.up_match(stack[:-1])
+
+class StackWildcardMatch(StackMatchBase):
+    def __init__(self, up_match):
+        self.up_match = up_match
+
+    def __call__(self, stack):
+        if self.up_match is None:
+            return True
+
+        n = len(stack)
+
+        for i in xrange(n):
+            if self.up_match(stack[:-i]):
+                return True
+
+        return False
+
+# }}}
+
+# {{{ stack match parsing
+
+def parse_stack_match(smatch):
+    """Syntax example::
+
+        lowest < next < ... < highest
+
+    where `lowest` is necessarily the bottom of the stack. There is currently
+    no way to anchor to the top of the stack.
+    """
+
+    if isinstance(smatch, StackMatchBase):
+        return smatch
+
+    match = AllStackMatch()
+
+    if smatch is None:
+        return match
+
+    components = smatch.split("<")
+
+    for comp in components[::-1]:
+        comp = comp.strip()
+        if comp == "...":
+            match = StackWildcardMatch(match)
+        else:
+            match = StackIdMatch(parse_id_match(comp), match)
+
+    return match
+
+# }}}
+
+
+
+# vim: foldmethod=marker
diff --git a/loopy/creation.py b/loopy/creation.py
index 55bc29d44..5c16d24cd 100644
--- a/loopy/creation.py
+++ b/loopy/creation.py
@@ -247,172 +247,31 @@ def create_temporaries(knl):
 
 # }}}
 
-# {{{ reduction iname duplication
+# {{{ check for reduction iname duplication
 
-def duplicate_reduction_inames(kernel):
+def check_for_reduction_inames_duplication_requests(kernel):
 
     # {{{ helper function
 
-    newly_created_vars = set()
-
-    def duplicate_reduction_inames(reduction_expr, rec):
-        child = rec(reduction_expr.expr)
-        new_red_inames = []
-        did_something = False
-
+    def check_reduction_inames(reduction_expr, rec):
         for iname in reduction_expr.inames:
             if iname.startswith("@"):
-                new_iname = kernel.make_unique_var_name(iname[1:]+"_"+name_base,
-                        newly_created_vars)
-
-                old_inames.append(iname.lstrip("@"))
-                new_inames.append(new_iname)
-                newly_created_vars.add(new_iname)
-                new_red_inames.append(new_iname)
-                did_something = True
-            else:
-                new_red_inames.append(iname)
-
-        if did_something:
-            from loopy.symbolic import SubstitutionMapper
-            from pymbolic.mapper.substitutor import make_subst_func
-            from pymbolic import var
-
-            subst_dict = dict(
-                    (old_iname, var(new_iname))
-                    for old_iname, new_iname in zip(
-                        reduction_expr.untagged_inames, new_red_inames))
-            subst_map = SubstitutionMapper(make_subst_func(subst_dict))
-
-            child = subst_map(child)
-
-        from loopy.symbolic import Reduction
-        return Reduction(
-                operation=reduction_expr.operation,
-                inames=tuple(new_red_inames),
-                expr=child)
+                raise RuntimeError("Reduction iname duplication with '@' is no "
+                        "longer supported. Use loopy.duplicate_inames instead.")
 
     # }}}
 
-    from loopy.symbolic import ReductionCallbackMapper
-    from loopy.isl_helpers import duplicate_axes
-
-    new_domains = kernel.domains
-    new_insns = []
-
-    new_iname_to_tag = kernel.iname_to_tag.copy()
 
+    from loopy.symbolic import ReductionCallbackMapper
+    rcm = ReductionCallbackMapper(check_reduction_inames)
     for insn in kernel.instructions:
-        old_inames = []
-        new_inames = []
-        name_base = insn.id
-
-        new_insns.append(insn.copy(
-            expression=ReductionCallbackMapper(duplicate_reduction_inames)
-            (insn.expression)))
-
-        for old, new in zip(old_inames, new_inames):
-            new_domains = duplicate_axes(new_domains, [old], [new])
-            if old in kernel.iname_to_tag:
-                new_iname_to_tag[new] = kernel.iname_to_tag[old]
+        rcm(insn.expression)
 
-    new_substs = {}
     for sub_name, sub_rule in kernel.substitutions.iteritems():
-        old_inames = []
-        new_inames = []
-        name_base = sub_name
-
-        new_substs[sub_name] = sub_rule.copy(
-                expression=ReductionCallbackMapper(duplicate_reduction_inames)
-                (sub_rule.expression))
-
-        for old, new in zip(old_inames, new_inames):
-            new_domains = duplicate_axes(new_domains, [old], [new])
-            if old in kernel.iname_to_tag:
-                new_iname_to_tag[new] = kernel.iname_to_tag[old]
-
-    return kernel.copy(
-            instructions=new_insns,
-            substitutions=new_substs,
-            domains=new_domains,
-            iname_to_tag=new_iname_to_tag)
+        rcm(sub_rule.expression)
 
 # }}}
 
-# {{{ duplicate inames
-
-def duplicate_inames(knl):
-    new_insns = []
-    new_domains = knl.domains
-    new_iname_to_tag = knl.iname_to_tag.copy()
-
-    newly_created_vars = set()
-
-    for insn in knl.instructions:
-        if insn.duplicate_inames_and_tags:
-            insn_dup_iname_to_tag = dict(insn.duplicate_inames_and_tags)
-
-            if not set(insn_dup_iname_to_tag.keys()) <= knl.all_inames():
-                raise ValueError("In instruction '%s': "
-                        "cannot duplicate inames '%s'--"
-                        "they don't exist" % (
-                            insn.id,
-                            ",".join(
-                                set(insn_dup_iname_to_tag.keys())-knl.all_inames())))
-
-            # {{{ duplicate non-reduction inames
-
-            reduction_inames = insn.reduction_inames()
-
-            inames_to_duplicate = [iname
-                    for iname, tag in insn.duplicate_inames_and_tags
-                    if iname not in reduction_inames]
-
-            new_inames = [
-                    knl.make_unique_var_name(
-                        based_on=iname+"_"+insn.id,
-                        extra_used_vars=newly_created_vars)
-                    for iname in inames_to_duplicate]
-
-            for old_iname, new_iname in zip(inames_to_duplicate, new_inames):
-                new_tag = insn_dup_iname_to_tag[old_iname]
-                new_iname_to_tag[new_iname] = new_tag
-
-            newly_created_vars.update(new_inames)
-
-            from loopy.isl_helpers import duplicate_axes
-            new_domains = duplicate_axes(new_domains, inames_to_duplicate, new_inames)
-
-            from loopy.symbolic import SubstitutionMapper
-            from pymbolic.mapper.substitutor import make_subst_func
-            from pymbolic import var
-            old_to_new = dict(
-                    (old_iname, var(new_iname))
-                    for old_iname, new_iname in zip(inames_to_duplicate, new_inames))
-            subst_map = SubstitutionMapper(make_subst_func(old_to_new))
-            new_expression = subst_map(insn.expression)
-
-            # }}}
-
-            if len(inames_to_duplicate) < len(insn.duplicate_inames_and_tags):
-                raise RuntimeError("cannot use [|...] syntax to rename reduction "
-                        "inames")
-
-            insn = insn.copy(
-                    assignee=subst_map(insn.assignee),
-                    expression=new_expression,
-                    forced_iname_deps=set(
-                        old_to_new.get(iname, iname) for iname in insn.forced_iname_deps),
-                    duplicate_inames_and_tags=[])
-
-        new_insns.append(insn)
-
-    return knl.copy(
-            instructions=new_insns,
-            domains=new_domains,
-            iname_to_tag=new_iname_to_tag)
-# }}}
-
 # {{{ kernel creation top-level
 
 def make_kernel(*args, **kwargs):
@@ -430,30 +289,11 @@ def make_kernel(*args, **kwargs):
                     iname_to_tag_requests=[])
 
     check_for_nonexistent_iname_deps(knl)
+    check_for_reduction_inames_duplication_requests(knl)
 
-    knl = duplicate_reduction_inames(knl)
-
-    # -------------------------------------------------------------------------
-    # Ordering dependency:
-    # -------------------------------------------------------------------------
-    # Must duplicate reduction inames before tagging reduction inames as
-    # sequential because otherwise the latter operation will run into @iname
-    # (i.e. duplication) markers and not understand them.
-    # -------------------------------------------------------------------------
 
     knl = tag_reduction_inames_as_sequential(knl)
-
     knl = create_temporaries(knl)
-    knl = duplicate_inames(knl)
-
-    # -------------------------------------------------------------------------
-    # Ordering dependency:
-    # -------------------------------------------------------------------------
-    # Must duplicate inames before expanding CSEs, otherwise inames within the
-    # scope of duplication might be CSE'd out to a different instruction and
-    # never be found by duplication.
-    # -------------------------------------------------------------------------
-
     knl = expand_cses(knl)
 
     # -------------------------------------------------------------------------
@@ -462,6 +302,7 @@ def make_kernel(*args, **kwargs):
     # Must create temporary before checking for writes to temporary variables
     # that are domain parameters.
     # -------------------------------------------------------------------------
+
     check_for_multiple_writes_to_loop_bounds(knl)
     check_for_duplicate_names(knl)
     check_written_variable_names(knl)
diff --git a/loopy/cse.py b/loopy/cse.py
index e5edaa10c..48ef6e356 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -27,7 +27,8 @@ THE SOFTWARE.
 
 import islpy as isl
 from islpy import dim_type
-from loopy.symbolic import get_dependencies, SubstitutionMapper
+from loopy.symbolic import (get_dependencies, SubstitutionMapper,
+        ExpandingIdentityMapper)
 from pymbolic.mapper.substitutor import make_subst_func
 import numpy as np
 
@@ -39,16 +40,13 @@ from pymbolic import var
 
 class InvocationDescriptor(Record):
     __slots__ = [
-            "expr",
             "args",
             "expands_footprint",
             "is_in_footprint",
 
-            # Record from which substitution rule this invocation of the rule
-            # being precomputed originated. If all invocations end up being
-            # in-footprint, then the replacement with the prefetch can be made
-            # within the rule.
-            "from_subst_rule"
+            # Remember where the invocation happened, in terms of the expansion
+            # call stack.
+            "expansion_stack",
             ]
 
 
@@ -379,9 +377,161 @@ def simplify_via_aff(expr):
 
 
 
-def precompute(kernel, subst_use, dtype, sweep_inames=[],
+class InvocationGatherer(ExpandingIdentityMapper):
+    def __init__(self, kernel, subst_name, subst_tag, within):
+        ExpandingIdentityMapper.__init__(self,
+                kernel.substitutions, kernel.get_var_name_generator())
+
+        from loopy.symbolic import ParametrizedSubstitutor
+        self.subst_expander = ParametrizedSubstitutor(
+                kernel.substitutions)
+
+        self.kernel = kernel
+        self.subst_name = subst_name
+        self.subst_tag = subst_tag
+        self.within = within
+
+        self.invocation_descriptors = []
+
+    def map_substitution(self, name, tag, arguments, expn_state):
+        process_me = name == self.subst_name
+
+        if self.subst_tag is not None and self.subst_tag != tag:
+            process_me = False
+
+        process_me = process_me and self.within(expn_state.stack)
+
+        if not process_me:
+            return ExpandingIdentityMapper.map_substitution(
+                    self, name, tag, arguments, expn_state)
+
+        rule = self.old_subst_rules[name]
+        arg_context = self.make_new_arg_context(
+                    name, rule.arguments, arguments, expn_state.arg_context)
+
+        arg_deps = set()
+        for arg_val in arg_context.itervalues():
+            arg_deps = (arg_deps
+                    | get_dependencies(self.subst_expander(arg_val, insn_id=None)))
+
+        if not arg_deps <= self.kernel.all_inames():
+            from warnings import warn
+            warn("Precompute arguments in '%s(%s)' do not consist exclusively "
+                    "of inames and constants--specifically, these are "
+                    "not inames: %s. Ignoring." % (
+                        name,
+                        ", ".join(str(arg) for arg in arguments),
+                        ", ".join(arg_deps - self.kernel.all_inames()),
+                        ))
+
+            return ExpandingIdentityMapper.map_substitution(
+                    self, name, tag, arguments, expn_state)
+
+        self.invocation_descriptors.append(
+                InvocationDescriptor(
+                    args=[arg_context[arg_name] for arg_name in rule.arguments],
+                    expansion_stack=expn_state.stack))
+
+        return 0 # exact value irrelevant
+
+
+
+
+class InvocationReplacer(ExpandingIdentityMapper):
+    def __init__(self, kernel, subst_name, subst_tag, within,
+            invocation_descriptors,
+            storage_axis_names, storage_axis_sources,
+            storage_base_indices, non1_storage_axis_names,
+            target_var_name):
+        ExpandingIdentityMapper.__init__(self,
+                kernel.substitutions, kernel.get_var_name_generator())
+
+        from loopy.symbolic import ParametrizedSubstitutor
+        self.subst_expander = ParametrizedSubstitutor(
+                kernel.substitutions, kernel.get_var_name_generator())
+
+        self.kernel = kernel
+        self.subst_name = subst_name
+        self.subst_tag = subst_tag
+        self.within = within
+
+        self.invocation_descriptors = invocation_descriptors
+
+        self.storage_axis_names = storage_axis_names
+        self.storage_axis_sources = storage_axis_sources
+        self.storage_base_indices = storage_base_indices
+        self.non1_storage_axis_names = non1_storage_axis_names
+
+        self.target_var_name = target_var_name
+
+    def map_substitution(self, name, tag, arguments, expn_state):
+        process_me = name == self.subst_name
+
+        if self.subst_tag is not None and self.subst_tag != tag:
+            process_me = False
+
+        process_me = process_me and self.within(expn_state.stack)
+
+        # {{{ find matching invocation descriptor
+
+        rule = self.old_subst_rules[name]
+        arg_context = self.make_new_arg_context(
+                    name, rule.arguments, arguments, expn_state.arg_context)
+        args = [arg_context[arg_name] for arg_name in rule.arguments]
+
+        if not process_me:
+            return ExpandingIdentityMapper.map_substitution(
+                    self, name, tag, arguments, expn_state)
+
+        matching_invdesc = None
+        for invdesc in self.invocation_descriptors:
+            if invdesc.args == args and expn_state.stack:
+                # Could be more than one, that's fine.
+                matching_invdesc = invdesc
+                break
+
+        assert matching_invdesc is not None
+
+        invdesc = matching_invdesc
+        del matching_invdesc
+
+        # }}}
+
+        if not invdesc.is_in_footprint:
+            return ExpandingIdentityMapper.map_substitution(
+                    self, name, tag, arguments, expn_state)
+
+        assert len(arguments) == len(rule.arguments)
+
+        stor_subscript = []
+        for sax_name, sax_source, sax_base_idx in zip(
+                self.storage_axis_names,
+                self.storage_axis_sources, 
+                self.storage_base_indices):
+            if sax_name not in self.non1_storage_axis_names:
+                continue
+
+            if isinstance(sax_source, int):
+                # an argument
+                ax_index = arguments[sax_source]
+            else:
+                # an iname
+                ax_index = var(sax_source)
+
+            ax_index = simplify_via_aff(ax_index - sax_base_idx)
+            stor_subscript.append(ax_index)
+
+        new_outer_expr = var(self.target_var_name)
+        if stor_subscript:
+            new_outer_expr = new_outer_expr[tuple(stor_subscript)]
+
+        return new_outer_expr
+        # can't possibly be nested, don't recurse
+
+
+def precompute(kernel, subst_use, sweep_inames=[], within=None,
         storage_axes=None, new_storage_axis_names=None, storage_axis_to_tag={},
-        default_tag="l.auto"):
+        default_tag="l.auto", dtype=None):
     """Precompute the expression described in the substitution rule determined by
     *subst_use* and store it in a temporary array. A precomputation needs two
     things to operate, a list of *sweep_inames* (order irrelevant) and an
@@ -426,6 +576,7 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     :arg sweep_inames: A :class:`list` of inames and/or rule argument names to be swept.
     :arg storage_axes: A :class:`list` of inames and/or rule argument names/indices to be used as storage axes.
+    :arg within: a stack match as understood by :func:`loopy.context_matching.parse_stack_match`.
 
     If `storage_axes` is not specified, it defaults to the arrangement
     `<direct sweep axes><arguments>` with the direct sweep axes being the
@@ -486,6 +637,13 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
                 raise ValueError("not all uses in subst_use agree "
                         "on rule name and tag")
 
+    from loopy.context_matching import parse_stack_match
+    within = parse_stack_match(within)
+
+    from loopy import infer_type
+    if dtype is None:
+        dtype = infer_type
+
     # }}}
 
     # {{{ process invocations in footprint generators, start invocation_descriptors
@@ -504,9 +662,9 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
                         "be substitution rule invocation")
 
             invocation_descriptors.append(
-                    InvocationDescriptor(expr=fpg, args=args,
+                    InvocationDescriptor(args=args,
                         expands_footprint=True,
-                        from_subst_rule=None))
+                        ))
 
     # }}}
 
@@ -520,63 +678,14 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     # {{{ gather up invocations in kernel code, finish invocation_descriptors
 
-    current_subst_rule_stack = []
-
-    # We need to work on the fully expanded form of an expression.
-    # To that end, instantiate a substitutor.
-    from loopy.symbolic import ParametrizedSubstitutor
-    rules_except_mine = kernel.substitutions.copy()
-    del rules_except_mine[subst_name]
-    subst_expander = ParametrizedSubstitutor(rules_except_mine,
-            one_level=True)
-
-    def gather_substs(expr, name, tag, args, rec):
-        if subst_name != name:
-            if name in subst_expander.rules:
-                # We can't deal with invocations that involve other substitution's
-                # arguments. Therefore, fully expand each encountered substitution
-                # rule and look at the invocations of subst_name occurring in its
-                # body.
-
-                expanded_expr = subst_expander(expr)
-                current_subst_rule_stack.append(name)
-                result = rec(expanded_expr)
-                current_subst_rule_stack.pop()
-                return result
-
-            else:
-                return None
-
-        if subst_tag is not None and subst_tag != tag:
-            # use fall-back identity mapper
-            return None
-
-        if len(args) != len(subst.arguments):
-            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
-                    % (subst_name, len(args), len(subst.arguments), ))
-
-        arg_deps = get_dependencies(args)
-        if not arg_deps <= kernel.all_inames():
-            raise RuntimeError("CSE arguments in '%s' do not consist "
-                    "exclusively of inames" % expr)
+    invg = InvocationGatherer(kernel, subst_name, subst_tag, within)
 
-        if current_subst_rule_stack:
-            current_subst_rule = current_subst_rule_stack[-1]
-        else:
-            current_subst_rule = None
+    for insn in kernel.instructions:
+        invg(insn.expression, insn.id)
 
+    for invdesc in invg.invocation_descriptors:
         invocation_descriptors.append(
-                InvocationDescriptor(expr=expr, args=args,
-                    expands_footprint=footprint_generators is None,
-                    from_subst_rule=current_subst_rule))
-
-        return expr
-
-    from loopy.symbolic import SubstitutionCallbackMapper
-    scm = SubstitutionCallbackMapper(names_filter=None, func=gather_substs)
-
-    for insn in kernel.instructions:
-        scm(insn.expression)
+                invdesc.copy(expands_footprint=footprint_generators is None))
 
     if not invocation_descriptors:
         raise RuntimeError("no invocations of '%s' found" % subst_name)
@@ -608,7 +717,8 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
     from loopy.symbolic import ParametrizedSubstitutor
     submap = ParametrizedSubstitutor(kernel.substitutions)
 
-    value_inames = get_dependencies(submap(subst.expression)) & kernel.all_inames()
+    value_inames = get_dependencies(
+            submap(subst.expression, insn_id=None)) & kernel.all_inames()
     if value_inames - expanding_usage_arg_deps < extra_storage_axes:
         raise RuntimeError("unreferenced sweep inames specified: "
                 + ", ".join(extra_storage_axes - value_inames - expanding_usage_arg_deps))
@@ -736,121 +846,13 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     # {{{ substitute rule into expressions in kernel (if within footprint)
 
-    left_unused_subst_rule_invocations = [False]
-
-    def do_substs(expr, name, tag, args, rec):
-        if tag != subst_tag:
-            left_unused_subst_rule_invocations[0] = True
-            return expr
-
-        # {{{ check if current use is in-footprint
-
-        if current_subst_rule is None:
-            # The current subsitution was *not* found inside another
-            # substitution rule. Try and dig up the corresponding invocation
-            # descriptor.
-
-            found = False
-            for invdesc in invocation_descriptors:
-                if expr == invdesc.expr:
-                    found = True
-                    break
-
-            if footprint_generators is None:
-                # We only have a right to find the expression if the
-                # invocation descriptors if they were generated by a scan
-                # of the code in the first place. If the user gave us
-                # the footprint generators, that isn't true.
-
-                assert found, expr
-
-            if not found or not invdesc.is_in_footprint:
-                left_unused_subst_rule_invocations[0] = True
-                return expr
-
-        else:
-            # The current subsitution *was* found inside another substitution
-            # rule. We can't dig up the corresponding invocation descriptor,
-            # because it was the result of expanding that outer substitution
-            # rule. But we do know what the current outer substitution rule is,
-            # and we can check if all uses within that rule were uniformly
-            # in-footprint. If so, we'll go ahead, otherwise we'll bomb out.
-
-            current_rule_invdescs_in_footprint = [
-                    invdesc.is_in_footprint
-                    for invdesc in invocation_descriptors
-                    if invdesc.from_subst_rule == current_subst_rule]
-
-            from pytools import all
-            all_in = all(current_rule_invdescs_in_footprint)
-            all_out = all(not b for b in current_rule_invdescs_in_footprint)
-
-            assert not (all_in and all_out)
-
-            if not (all_in or all_out):
-                raise RuntimeError("substitution '%s' (being precomputed) is used "
-                        "from within substitution '%s', but not all uses of "
-                        "'%s' within '%s' "
-                        "are uniformly within-footprint or outside of the footprint, "
-                        "making a unique replacement of '%s' impossible. Please expand "
-                        "'%s' and try again."
-                        % (subst_name, current_subst_rule,
-                            subst_name, current_subst_rule,
-                            subst_name, current_subst_rule))
-
-            if all_out:
-                left_unused_subst_rule_invocations[0] = True
-                return expr
-
-            assert all_in
-
-        # }}}
-
-        if len(args) != len(subst.arguments):
-            raise ValueError("invocation of '%s' with too few arguments"
-                    % name)
-
-        stor_subscript = []
-        for sax_name, sax_source, sax_base_idx in zip(
-                storage_axis_names, storage_axis_sources, storage_base_indices):
-            if sax_name not in non1_storage_axis_names:
-                continue
-
-            if isinstance(sax_source, int):
-                # an argument
-                ax_index = args[sax_source]
-            else:
-                # an iname
-                ax_index = var(sax_source)
-
-            ax_index = simplify_via_aff(ax_index - sax_base_idx)
-            stor_subscript.append(ax_index)
-
-        new_outer_expr = var(target_var_name)
-        if stor_subscript:
-            new_outer_expr = new_outer_expr[tuple(stor_subscript)]
-
-        return new_outer_expr
-        # can't possibly be nested, don't recurse
-
-    new_insns = [compute_insn]
-
-    current_subst_rule = None
-    sub_map = SubstitutionCallbackMapper([subst_name], do_substs)
-    for insn in kernel.instructions:
-        new_insn = insn.copy(expression=sub_map(insn.expression))
-        new_insns.append(new_insn)
-
-    # also catch uses of our rule in other substitution rules
-    new_substs = {}
-    for s in kernel.substitutions.itervalues():
-        current_subst_rule = s.name
-        new_substs[s.name] = s.copy(
-                expression=sub_map(s.expression))
+    invr = InvocationReplacer(kernel, subst_name, subst_tag, within,
+            invocation_descriptors,
+            storage_axis_names, storage_axis_sources,
+            storage_base_indices, non1_storage_axis_names,
+            target_var_name)
 
-    # If the subst above caught all uses of the subst rule, get rid of it.
-    if not left_unused_subst_rule_invocations[0]:
-        del new_substs[subst_name]
+    kernel = invr.map_kernel(kernel)
 
     # }}}
 
@@ -872,8 +874,7 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
 
     result =  kernel.copy(
             domains=domch.get_domains_with(new_domain),
-            instructions=new_insns,
-            substitutions=new_substs,
+            instructions=[compute_insn] + kernel.instructions,
             temporary_variables=new_temporary_variables)
 
     from loopy import tag_inames
diff --git a/loopy/kernel.py b/loopy/kernel.py
index d9fd7c274..014a4f6e3 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -308,6 +308,8 @@ class SubstitutionRule(Record):
     """
 
     def __init__(self, name, arguments, expression):
+        assert isinstance(arguments, tuple)
+
         Record.__init__(self,
                 name=name, arguments=arguments, expression=expression)
 
@@ -343,15 +345,12 @@ class Instruction(Record):
 
     :ivar temp_var_type: if not None, a type that will be assigned to the new temporary variable
         created from the assignee
-    :ivar duplicate_inames_and_tags: a list of inames used in the instruction that will be duplicated onto
-        different inames.
     """
     def __init__(self,
             id, assignee, expression,
             forced_iname_deps=frozenset(), insn_deps=set(), boostable=None,
             boostable_into=None,
-            temp_var_type=None, duplicate_inames_and_tags=[],
-            priority=0):
+            temp_var_type=None, priority=0):
 
         from loopy.symbolic import parse
         if isinstance(assignee, str):
@@ -368,14 +367,13 @@ class Instruction(Record):
                 insn_deps=insn_deps, boostable=boostable,
                 boostable_into=boostable_into,
                 temp_var_type=temp_var_type,
-                duplicate_inames_and_tags=duplicate_inames_and_tags,
                 priority=priority)
 
     @memoize_method
     def reduction_inames(self):
         def map_reduction(expr, rec):
             rec(expr.expr)
-            for iname in expr.untagged_inames:
+            for iname in expr.inames:
                 result.add(iname)
 
         from loopy.symbolic import ReductionCallbackMapper
@@ -631,6 +629,29 @@ def _generate_unique_possibilities(prefix):
         yield "%s_%d" % (prefix, try_num)
         try_num += 1
 
+class _UniqueNameGenerator:
+    def __init__(self, existing_names):
+        self.existing_names = existing_names.copy()
+
+    def is_name_conflicting(self, name):
+        return name in self.existing_names
+
+    def add_name(self, name):
+        assert name not in self.existing_names
+        self.existing_names.add(name)
+
+    def add_names(self, names):
+        assert not frozenset(names) & self.existing_names
+        self.existing_names.update(names)
+
+    def __call__(self, based_on="var"):
+        for var_name in _generate_unique_possibilities(based_on):
+            if not self.is_name_conflicting(var_name):
+                break
+
+        self.existing_names.add(var_name)
+        return var_name
+
 _IDENTIFIER_RE = re.compile(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\b")
 
 def _gather_identifiers(s):
@@ -783,13 +804,7 @@ class LoopKernel(Record):
 
         # {{{ parse instructions
 
-        INAME_ENTRY_RE = re.compile(
-                r"^\s*(?P<iname>\w+)\s*(?:\:\s*(?P<tag>[\w.]+))?\s*$")
         INSN_RE = re.compile(
-                "\s*(?:\["
-                    "(?P<iname_deps_and_tags>[\s\w,:.]*)"
-                    "(?:\|(?P<duplicate_inames_and_tags>[\s\w,:.]*))?"
-                "\])?"
                 "\s*(?:\<(?P<temp_var_type>.*?)\>)?"
                 "\s*(?P<lhs>.+?)\s*(?<!\:)=\s*(?P<rhs>.+?)"
                 "\s*?(?:\{(?P<options>[\s\w=,:]+)\}\s*)?$"
@@ -798,32 +813,6 @@ class LoopKernel(Record):
                 r"^\s*(?P<lhs>.+?)\s*:=\s*(?P<rhs>.+)\s*$"
                 )
 
-        def parse_iname_and_tag_list(s):
-            dup_entries = [
-                    dep.strip() for dep in s.split(",")]
-            result = []
-            for entry in dup_entries:
-                if not entry:
-                    continue
-
-                entry_match = INAME_ENTRY_RE.match(entry)
-                if entry_match is None:
-                    raise RuntimeError(
-                            "could not parse iname:tag entry '%s'"
-                            % entry)
-
-                groups = entry_match.groupdict()
-                iname = groups["iname"]
-                assert iname
-
-                tag = None
-                if groups["tag"] is not None:
-                    tag = parse_tag(groups["tag"])
-
-                result.append((iname, tag))
-
-            return result
-
         def parse_insn(insn):
             insn_match = INSN_RE.match(insn)
             subst_match = SUBST_RE.match(insn)
@@ -870,20 +859,6 @@ class LoopKernel(Record):
                             raise ValueError("unrecognized instruction option '%s'"
                                     % opt_key)
 
-                if groups["iname_deps_and_tags"] is not None:
-                    inames_and_tags = parse_iname_and_tag_list(
-                            groups["iname_deps_and_tags"])
-                    forced_iname_deps = frozenset(iname for iname, tag in inames_and_tags)
-                    iname_to_tag_requests.update(dict(inames_and_tags))
-                else:
-                    forced_iname_deps = frozenset()
-
-                if groups["duplicate_inames_and_tags"] is not None:
-                    duplicate_inames_and_tags = parse_iname_and_tag_list(
-                            groups["duplicate_inames_and_tags"])
-                else:
-                    duplicate_inames_and_tags = []
-
                 if groups["temp_var_type"] is not None:
                     if groups["temp_var_type"]:
                         temp_var_type = np.dtype(groups["temp_var_type"])
@@ -903,10 +878,9 @@ class LoopKernel(Record):
                             id=self.make_unique_instruction_id(
                                 parsed_instructions, based_on=insn_id),
                             insn_deps=insn_deps,
-                            forced_iname_deps=forced_iname_deps,
+                            forced_iname_deps=frozenset(),
                             assignee=lhs, expression=rhs,
                             temp_var_type=temp_var_type,
-                            duplicate_inames_and_tags=duplicate_inames_and_tags,
                             priority=priority))
 
             elif subst_match is not None:
@@ -930,7 +904,7 @@ class LoopKernel(Record):
 
                 substitutions[subst_name] = SubstitutionRule(
                         name=subst_name,
-                        arguments=arg_names,
+                        arguments=tuple(arg_names),
                         expression=rhs)
 
         def parse_if_necessary(insn):
@@ -1111,11 +1085,16 @@ class LoopKernel(Record):
                 | set(self.all_inames()))
 
     def make_unique_var_name(self, based_on="var", extra_used_vars=set()):
-        used_vars = self.all_variable_names() | extra_used_vars
+        from warnings import warn
+        warn("make_unique_var_name is deprecated, use get_var_name_generator "
+                "instead", DeprecationWarning, stacklevel=2)
 
-        for var_name in _generate_unique_possibilities(based_on):
-            if var_name not in used_vars:
-                return var_name
+        gen = self.get_var_name_generator()
+        gen.add_names(extra_used_vars)
+        return gen(based_on)
+
+    def get_var_name_generator(self):
+        return _UniqueNameGenerator(self.all_variable_names())
 
     def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()):
         if insns is None:
@@ -1623,22 +1602,6 @@ class LoopKernel(Record):
 
     # }}}
 
-    def map_expressions(self, func, exclude_instructions=False):
-        if exclude_instructions:
-            new_insns = self.instructions
-        else:
-            new_insns = [insn.copy(
-                expression=func(insn.expression),
-                assignee=func(insn.assignee),
-                )
-                    for insn in self.instructions]
-
-        return self.copy(
-                instructions=new_insns,
-                substitutions=dict(
-                    (subst.name, subst.copy(expression=func(subst.expression)))
-                    for subst in self.substitutions.itervalues()))
-
     # {{{ pretty-printing
 
     def __str__(self):
@@ -1711,9 +1674,15 @@ def find_all_insn_inames(kernel):
     insn_id_to_inames = {}
     insn_assignee_inames = {}
 
+    all_read_deps = {}
+    all_write_deps = {}
+
+    from loopy.subst import expand_subst
+    kernel = expand_subst(kernel)
+
     for insn in kernel.instructions:
-        read_deps = get_dependencies(insn.expression)
-        write_deps = get_dependencies(insn.assignee)
+        all_read_deps[insn.id] = read_deps = get_dependencies(insn.expression)
+        all_write_deps[insn.id] = write_deps = get_dependencies(insn.assignee)
         deps = read_deps | write_deps
 
         iname_deps = (
@@ -1748,8 +1717,7 @@ def find_all_insn_inames(kernel):
             # of iname deps of all writers, and add those to insn's
             # dependencies.
 
-            for tv_name in (get_dependencies(insn.expression)
-                    & temp_var_names):
+            for tv_name in (all_read_deps[insn.id] & temp_var_names):
                 implicit_inames = None
 
                 for writer_id in writer_map[tv_name]:
@@ -1874,8 +1842,7 @@ class DomainChanger:
 
 # }}}
 
-
-# {{{ dot export
+# {{{ graphviz / dot export
 
 def get_dot_dependency_graph(kernel, iname_cluster=False, iname_edge=True):
     lines = []
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 4b2e29a10..ce5cf8c58 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -229,6 +229,7 @@ def realize_reduction(kernel, insn_id_filter=None):
 
     new_insns = []
     new_temporary_variables = kernel.temporary_variables.copy()
+    orig_temp_var_names = set(kernel.temporary_variables)
 
     from loopy.codegen.expression import TypeInferenceMapper
     type_inf_mapper = TypeInferenceMapper(kernel)
@@ -240,7 +241,7 @@ def realize_reduction(kernel, insn_id_filter=None):
         from pymbolic import var
 
         target_var_name = kernel.make_unique_var_name("acc_"+"_".join(expr.inames),
-                extra_used_vars=set(new_temporary_variables))
+                extra_used_vars=set(new_temporary_variables) - orig_temp_var_names)
         target_var = var(target_var_name)
 
         arg_dtype = type_inf_mapper(expr.expr)
@@ -437,11 +438,20 @@ def duplicate_private_temporaries_for_ilp(kernel):
     # }}}
 
     from pymbolic import var
-    return (kernel
-            .copy(temporary_variables=new_temp_vars)
-            .map_expressions(ExtraInameIndexInserter(
-                dict((var_name, tuple(var(iname) for iname in inames))
-                    for var_name, inames in var_to_new_ilp_inames.iteritems()))))
+    eiii = ExtraInameIndexInserter(
+            dict((var_name, tuple(var(iname) for iname in inames))
+                for var_name, inames in var_to_new_ilp_inames.iteritems()))
+
+
+    new_insns = [
+            insn.copy(
+                assignee=eiii(insn.assignee),
+                expression=eiii(insn.expression))
+            for insn in kernel.instructions]
+
+    return kernel.copy(
+        temporary_variables=new_temp_vars,
+        instructions=new_insns)
 
 # }}}
 
diff --git a/loopy/subst.py b/loopy/subst.py
index 1ed9788ca..eec7d74b9 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -91,7 +91,7 @@ def extract_subst(kernel, subst_name, template, parameters):
 
         if urecs:
             if len(urecs) > 1:
-                raise RuntimeError("ambiguous unification of '%s' with template '%s'" 
+                raise RuntimeError("ambiguous unification of '%s' with template '%s'"
                         % (expr, template))
 
             urec, = urecs
@@ -155,7 +155,7 @@ def extract_subst(kernel, subst_name, template, parameters):
     new_substs = {
             subst_name: SubstitutionRule(
                 name=subst_name,
-                arguments=parameters,
+                arguments=tuple(parameters),
                 expression=template,
                 )}
 
@@ -172,27 +172,13 @@ def extract_subst(kernel, subst_name, template, parameters):
 
 
 
-def expand_subst(kernel, subst_name=None):
-    if subst_name is None:
-        rules = kernel.substitutions
-    else:
-        rule = kernel.substitutions[subst_name]
-        rules = {rule.name: rule}
-
+def expand_subst(kernel, ctx_match=None):
     from loopy.symbolic import ParametrizedSubstitutor
-    submap = ParametrizedSubstitutor(rules)
-
-    if subst_name:
-        new_substs = kernel.substitutions.copy()
-        del new_substs[subst_name]
-    else:
-        new_substs = {}
-
-    return (kernel
-            .copy(substitutions=new_substs)
-            .map_expressions(submap))
-
-
+    from loopy.context_matching import parse_stack_match
+    submap = ParametrizedSubstitutor(kernel.substitutions,
+            kernel.get_var_name_generator(),
+            parse_stack_match(ctx_match))
 
+    return submap.map_kernel(kernel)
 
 # vim: foldmethod=marker
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index cbb664a27..860c9ec18 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -27,11 +27,11 @@ THE SOFTWARE.
 
 
 
-from pytools import memoize, memoize_method
+from pytools import memoize, memoize_method, Record
 import pytools.lex
 
 from pymbolic.primitives import (
-        Leaf, AlgebraicLeaf, Variable as VariableBase,
+        Leaf, AlgebraicLeaf, Variable,
         CommonSubexpression)
 
 from pymbolic.mapper import (
@@ -81,14 +81,14 @@ class TypedCSE(CommonSubexpression):
         return dict(dtype=self.dtype)
 
 
-class TaggedVariable(VariableBase):
+class TaggedVariable(Variable):
     """This is an identifier with a tag, such as 'matrix$one', where
     'one' identifies this specific use of the identifier. This mechanism
     may then be used to address these uses--such as by prefetching only
     accesses tagged a certain way.
     """
     def __init__(self, name, tag):
-        VariableBase.__init__(self, name)
+        Variable.__init__(self, name)
         self.tag = tag
 
     def __getinitargs__(self):
@@ -129,13 +129,8 @@ class Reduction(AlgebraicLeaf):
 
     @property
     @memoize_method
-    def untagged_inames(self):
-        return tuple(iname.lstrip("@") for iname in self.inames)
-
-    @property
-    @memoize_method
-    def untagged_inames_set(self):
-        return set(self.untagged_inames)
+    def inames_set(self):
+        return set(self.inames)
 
     mapper_method = intern("map_reduction")
 
@@ -157,14 +152,14 @@ class LinearSubscript(AlgebraicLeaf):
 # {{{ mappers with support for loopy-specific primitives
 
 class IdentityMapperMixin(object):
-    def map_reduction(self, expr):
-        return Reduction(expr.operation, expr.inames, self.rec(expr.expr))
+    def map_reduction(self, expr, *args):
+        return Reduction(expr.operation, expr.inames, self.rec(expr.expr, *args))
 
-    def map_tagged_variable(self, expr):
+    def map_tagged_variable(self, expr, *args):
         # leaf, doesn't change
         return expr
 
-    def map_loopy_function_identifier(self, expr):
+    def map_loopy_function_identifier(self, expr, *args):
         return expr
 
     map_linear_subscript = IdentityMapperBase.map_subscript
@@ -217,9 +212,8 @@ class StringifyMapper(StringifyMapperBase):
 
 class DependencyMapper(DependencyMapperBase):
     def map_reduction(self, expr):
-        from pymbolic.primitives import Variable
         return (self.rec(expr.expr)
-                - set(Variable(iname) for iname in expr.untagged_inames))
+                - set(Variable(iname) for iname in expr.inames))
 
     def map_tagged_variable(self, expr):
         return set([expr])
@@ -257,6 +251,187 @@ class UnidirectionalUnifier(UnidirectionalUnifierBase):
 
 # }}}
 
+# {{{ identity mapper that expands subst rules on the fly
+
+def parse_tagged_name(expr):
+    if isinstance(expr, TaggedVariable):
+        return expr.name, expr.tag
+    elif isinstance(expr, Variable):
+        return expr.name, None
+    else:
+        raise RuntimeError("subst rule name not understood: %s" % expr)
+
+class ExpansionState(Record):
+    """
+    :ivar stack: a tuple representing the current expansion stack, as a tuple
+        of (name, tag) pairs. At the top level, this should be initialized to a
+        tuple with the id of the calling instruction.
+    :ivar arg_context: a dict representing current argument values
+    """
+
+class ExpandingIdentityMapper(IdentityMapper):
+    """Note: the third argument dragged around by this mapper is the
+    current expansion state.
+    """
+
+    def __init__(self, old_subst_rules, make_unique_var_name):
+        self.old_subst_rules = old_subst_rules
+        self.make_unique_var_name = make_unique_var_name
+
+        # maps subst rule (args, bodies) to names
+        self.subst_rule_registry = dict(
+                ((rule.arguments, rule.expression), name)
+                for name, rule in old_subst_rules.iteritems())
+
+        # maps subst rule (args, bodies) to use counts
+        self.subst_rule_use_count = {}
+
+    def register_subst_rule(self, name, args, body):
+        """Returns a name (as a string) for a newly created substitution
+        rule.
+        """
+        key = (args, body)
+        existing_name = self.subst_rule_registry.get(key)
+
+        if existing_name is None:
+            new_name = self.make_unique_var_name(name)
+            self.subst_rule_registry[key] = new_name
+        else:
+            new_name = existing_name
+
+        self.subst_rule_use_count[key] = self.subst_rule_use_count.get(key, 0) + 1
+        return new_name
+
+    def map_variable(self, expr, expn_state):
+        name, tag = parse_tagged_name(expr)
+        if name not in self.old_subst_rules:
+            return IdentityMapper.map_variable(self, expr, expn_state)
+        else:
+            return self.map_substitution(name, tag, (), expn_state)
+
+    def map_call(self, expr, expn_state):
+        if not isinstance(expr.function, Variable):
+            return IdentityMapper.map_call(self, expr, expn_state)
+
+        name, tag = parse_tagged_name(expr.function)
+
+        if name not in self.old_subst_rules:
+            return IdentityMapper.map_call(self, expr, expn_state)
+        else:
+            return self.map_substitution(name, tag, expr.parameters, expn_state)
+
+    @staticmethod
+    def make_new_arg_context(rule_name, arg_names, arguments, arg_context):
+        if len(arg_names) != len(arguments):
+            raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
+                    % (rule_name, len(arguments), len(arg_names), ))
+
+        from pymbolic.mapper.substitutor import make_subst_func
+        arg_subst_map = SubstitutionMapper(make_subst_func(arg_context))
+        return dict(
+                (formal_arg_name, arg_subst_map(arg_value))
+                for formal_arg_name, arg_value in zip(arg_names, arguments))
+
+    def map_substitution(self, name, tag, arguments, expn_state):
+        rule = self.old_subst_rules[name]
+
+        rec_arguments = self.rec(arguments, expn_state)
+        new_expn_state = expn_state.copy(
+                stack=expn_state.stack + ((name, tag),),
+                arg_context=self.make_new_arg_context(
+                    name, rule.arguments, rec_arguments, expn_state.arg_context))
+
+        result = self.rec(rule.expression, new_expn_state)
+
+        new_name = self.register_subst_rule(name, rule.arguments, result)
+
+        if tag is None:
+            sym = Variable(new_name)
+        else:
+            sym = TaggedVariable(new_name, tag)
+
+        if arguments:
+            return sym(*rec_arguments)
+        else:
+            return sym
+
+    def __call__(self, expr, insn_id):
+        if insn_id is not None:
+            stack = (insn_id,)
+        else:
+            stack = ()
+
+        return IdentityMapper.__call__(self, expr, ExpansionState(
+            stack=stack, arg_context={}))
+
+    def get_new_substitutions(self):
+        from loopy.kernel import SubstitutionRule
+
+        result = {}
+        for key, name in self.subst_rule_registry.iteritems():
+            args, body = key
+
+            if self.subst_rule_use_count.get(key, 0):
+                result[name] = SubstitutionRule(
+                        name=name,
+                        arguments=args,
+                        expression=body)
+
+        return result
+
+    def map_kernel(self, kernel):
+        new_insns = [
+                insn.copy(
+                    assignee=self(insn.assignee, insn.id),
+                    expression=self(insn.expression, insn.id))
+                for insn in kernel.instructions]
+
+        return kernel.copy(
+            substitutions=self.get_new_substitutions(),
+            instructions=new_insns)
+
+
+# }}}
+
+# {{{ parametrized substitutor
+
+class ParametrizedSubstitutor(ExpandingIdentityMapper):
+    def __init__(self, rules, make_unique_var=None, ctx_match=None):
+        ExpandingIdentityMapper.__init__(self, rules, make_unique_var)
+
+        if ctx_match is None:
+            from loopy.context_matching import AllStackMatch
+            ctx_match = AllStackMatch()
+
+        self.ctx_match = ctx_match
+
+    def map_substitution(self, name, tag, arguments, expn_state):
+        new_stack = expn_state.stack + ((name, tag),)
+        if self.ctx_match(new_stack):
+            # expand
+            rule = self.old_subst_rules[name]
+
+            new_expn_state = expn_state.copy(
+                    stack=new_stack,
+                    arg_context=self.make_new_arg_context(
+                        name, rule.arguments, arguments, expn_state.arg_context))
+
+            result = self.rec(rule.expression, new_expn_state)
+
+            # substitute in argument values
+            from pymbolic.mapper.substitutor import make_subst_func
+            subst_map = SubstitutionMapper(make_subst_func(
+                new_expn_state.arg_context))
+
+            return subst_map(result)
+
+        else:
+            # do not expand
+            return ExpandingIdentityMapper.map_substitution(
+                    self, name, tag, arguments, expn_state)
+
+# }}}
+
 # {{{ functions to primitives, parsing
 
 class VarToTaggedVarMapper(IdentityMapper):
@@ -343,7 +518,7 @@ class FunctionToPrimitiveMapper(IdentityMapper):
 
         return Reduction(operation, tuple(processed_inames), red_expr)
 
-# {{{ parser extension
+# {{{ customization to pymbolic parser
 
 _open_dbl_bracket = intern("open_dbl_bracket")
 _close_dbl_bracket = intern("close_dbl_bracket")
@@ -392,26 +567,6 @@ def parse(expr_str):
 
 # }}}
 
-# {{{ reduction loop splitter
-
-class ReductionLoopSplitter(IdentityMapper):
-    def __init__(self, old_iname, outer_iname, inner_iname):
-        self.old_iname = old_iname
-        self.outer_iname = outer_iname
-        self.inner_iname = inner_iname
-
-    def map_reduction(self, expr):
-        if self.old_iname in expr.inames:
-            new_inames = list(expr.inames)
-            new_inames.remove(self.old_iname)
-            new_inames.extend([self.outer_iname, self.inner_iname])
-            return Reduction(expr.operation, tuple(new_inames),
-                        expr.expr)
-        else:
-            return IdentityMapper.map_reduction(self, expr)
-
-# }}}
-
 # {{{ coefficient collector
 
 class CoefficientCollector(RecursiveMapper):
@@ -641,134 +796,13 @@ class IndexVariableFinder(CombineMapper):
     def map_reduction(self, expr):
         result = self.rec(expr.expr)
 
-        if not (expr.untagged_inames_set & result):
+        if not (expr.inames_set & result):
             raise RuntimeError("reduction '%s' does not depend on "
                     "reduction inames (%s)" % (expr, ",".join(expr.inames)))
         if self.include_reduction_inames:
             return result
         else:
-            return result - expr.untagged_inames_set
-
-# }}}
-
-# {{{ substitution callback mapper
-
-class SubstitutionCallbackMapper(IdentityMapper):
-    @staticmethod
-    def parse_filter(filt):
-        if not isinstance(filt, tuple):
-            components = filt.split("$")
-            if len(components) == 1:
-                return (components[0], None)
-            elif len(components) == 2:
-                return tuple(components)
-            else:
-                raise RuntimeError("too many components in '%s'" % filt)
-        else:
-            if len(filt) != 2:
-                raise RuntimeError("substitution name filters "
-                        "may have at most two components")
-
-            return filt
-
-    def __init__(self, names_filter, func):
-        if names_filter is not None:
-            new_names_filter = []
-            for filt in names_filter:
-                new_names_filter.append(self.parse_filter(filt))
-
-            self.names_filter = new_names_filter
-        else:
-            self.names_filter = names_filter
-
-        self.func = func
-
-    def parse_name(self, expr):
-        from pymbolic.primitives import Variable
-        if isinstance(expr, TaggedVariable):
-            e_name, e_tag = expr.name, expr.tag
-        elif isinstance(expr, Variable):
-            e_name, e_tag = expr.name, None
-        else:
-            return None
-
-        if self.names_filter is not None:
-            for filt_name, filt_tag in self.names_filter:
-                if e_name == filt_name:
-                    if filt_tag is None or filt_tag == e_tag:
-                        return e_name, e_tag
-        else:
-            return e_name, e_tag
-
-        return None
-
-    def map_variable(self, expr):
-        parsed_name = self.parse_name(expr)
-        if parsed_name is None:
-            return getattr(IdentityMapper, expr.mapper_method)(self, expr)
-
-        name, tag = parsed_name
-
-        result = self.func(expr, name, tag, (), self.rec)
-        if result is None:
-            return getattr(IdentityMapper, expr.mapper_method)(self, expr)
-        else:
-            return result
-
-    map_tagged_variable = map_variable
-
-    def map_call(self, expr):
-        from pymbolic.primitives import Lookup
-        if isinstance(expr.function, Lookup):
-            raise RuntimeError("dotted name '%s' not allowed as "
-                    "function identifier" % expr.function)
-
-        parsed_name = self.parse_name(expr.function)
-
-        if parsed_name is None:
-            return IdentityMapper.map_call(self, expr)
-
-        name, tag = parsed_name
-
-        result = self.func(expr, name, tag, expr.parameters, self.rec)
-        if result is None:
-            return IdentityMapper.map_call(self, expr)
-        else:
-            return result
-
-# }}}
-
-# {{{ parametrized substitutor
-
-class ParametrizedSubstitutor(object):
-    def __init__(self, rules, one_level=False):
-        self.rules = rules
-        self.one_level = one_level
-
-    def __call__(self, expr):
-        level = [0]
-
-        def expand_if_known(expr, name, instance, args, rec):
-            if self.one_level and level[0] > 0:
-                return None
-
-            rule = self.rules[name]
-            if len(rule.arguments) != len(args):
-                raise RuntimeError("Rule '%s' invoked with %d arguments (needs %d)"
-                        % (name, len(args), len(rule.arguments), ))
-
-            from pymbolic.mapper.substitutor import make_subst_func
-            subst_map = SubstitutionMapper(make_subst_func(
-                dict(zip(rule.arguments, args))))
-
-            level[0] += 1
-            result = rec(subst_map(rule.expression))
-            level[0] -= 1
-
-            return result
-
-        scm = SubstitutionCallbackMapper(self.rules.keys(), expand_if_known)
-        return scm(expr)
+            return result - expr.inames_set
 
 # }}}
 
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 75604febb..0bd22021a 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -29,7 +29,6 @@ import numpy as np
 import numpy.linalg as la
 import pyopencl as cl
 import pyopencl.array as cl_array
-import pyopencl.clrandom as cl_random
 import loopy as lp
 
 from pyopencl.tools import pytest_generate_tests_for_pyopencl \
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 84f8e0c21..8798ac6ec 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -39,6 +39,40 @@ __all__ = ["pytest_generate_tests",
 
 
 
+def test_complicated_subst(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0],
+            "{[i]: 0<=i<n}",
+            """
+                f(x) := x*a[x]
+                g(x) := 12 + f(x)
+                h(x) := 1 + g(x) + 20*g$two(x)
+
+                a[i] = h$one(i) * h$two(i)
+                """,
+            [
+                lp.GlobalArg("a", np.float32, shape=("n",)),
+                lp.ValueArg("n", np.int32),
+                ])
+
+    from loopy.subst import expand_subst
+    knl = expand_subst(knl, "g$two < h$two")
+
+    print knl
+
+    sr_keys = knl.substitutions.keys()
+    for letter, how_many in [
+            ("f", 1),
+            ("g", 1),
+            ("h", 2)
+            ]:
+        substs_with_letter = sum(1 for k in sr_keys if k.startswith(letter))
+        assert substs_with_letter == how_many
+
+
+
+
 def test_type_inference_no_artificial_doubles(ctx_factory):
     ctx = ctx_factory()
 
@@ -135,11 +169,13 @@ def test_owed_barriers(ctx_factory):
     knl = lp.make_kernel(ctx.devices[0],
             "{[i]: 0<=i<100}",
             [
-                "[i:l.0] <float32> z[i] = a[i]"
+                "<float32> z[i] = a[i]"
                 ],
             [lp.GlobalArg("a", np.float32, shape=(100,))]
             )
 
+    knl = lp.tag_inames(knl, dict(i="l.0"))
+
     kernel_gen = lp.generate_loop_schedules(knl)
     kernel_gen = lp.check_kernels(kernel_gen)
 
@@ -156,11 +192,13 @@ def test_wg_too_small(ctx_factory):
     knl = lp.make_kernel(ctx.devices[0],
             "{[i]: 0<=i<100}",
             [
-                "[i:l.0] <float32> z[i] = a[i] {id=copy}"
+                "<float32> z[i] = a[i] {id=copy}"
                 ],
             [lp.GlobalArg("a", np.float32, shape=(100,))],
             local_sizes={0: 16})
 
+    knl = lp.tag_inames(knl, dict(i="l.0"))
+
     kernel_gen = lp.generate_loop_schedules(knl)
     kernel_gen = lp.check_kernels(kernel_gen)
 
@@ -242,7 +280,7 @@ def test_multi_cse(ctx_factory):
     knl = lp.make_kernel(ctx.devices[0],
             "{[i]: 0<=i<100}",
             [
-                "[i] <float32> z[i] = a[i] + a[i]**2"
+                "<float32> z[i] = a[i] + a[i]**2"
                 ],
             [lp.GlobalArg("a", np.float32, shape=(100,))],
             local_sizes={0: 16})
@@ -816,7 +854,7 @@ def test_ilp_write_race_detection_global(ctx_factory):
             "[n] -> {[i,j]: 0<=i,j<n }",
             ],
             [
-                "[j:ilp] a[i] = 5+i+j",
+                "a[i] = 5+i+j",
                 ],
             [
                 lp.GlobalArg("a", np.float32),
@@ -824,6 +862,8 @@ def test_ilp_write_race_detection_global(ctx_factory):
                 ],
             assumptions="n>=1")
 
+    knl = lp.tag_inames(knl, dict(j="ilp"))
+
     from loopy.check import WriteRaceConditionError
     import pytest
     with pytest.raises(WriteRaceConditionError):
@@ -838,10 +878,12 @@ def test_ilp_write_race_avoidance_local(ctx_factory):
     knl = lp.make_kernel(ctx.devices[0],
             "{[i,j]: 0<=i<16 and 0<=j<17 }",
             [
-                "[i:l.0, j:ilp] <> a[i] = 5+i+j",
+                "<> a[i] = 5+i+j",
                 ],
             [])
 
+    knl = lp.tag_inames(knl, dict(i="l.0", j="ilp"))
+
     for k in lp.generate_loop_schedules(knl):
         assert k.temporary_variables["a"].shape == (16,17)
 
@@ -854,10 +896,12 @@ def test_ilp_write_race_avoidance_private(ctx_factory):
     knl = lp.make_kernel(ctx.devices[0],
             "{[j]: 0<=j<16 }",
             [
-                "[j:ilp] <> a = 5+j",
+                "<> a = 5+j",
                 ],
             [])
 
+    knl = lp.tag_inames(knl, dict(j="ilp"))
+
     for k in lp.generate_loop_schedules(knl):
         assert k.temporary_variables["a"].shape == (16,)
 
-- 
GitLab