From acfd14acec6c8f58e0d6cfbeef60124489fc2627 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 19 Apr 2015 19:12:49 -0500
Subject: [PATCH] Refactor rule-aware term rewriting using
 SubstitutionRuleMappingContext

---
 MEMO                     |   2 +-
 loopy/__init__.py        |  85 +++++++++-----
 loopy/buffer.py          |  19 +--
 loopy/kernel/creation.py |  18 +--
 loopy/padding.py         |  16 +--
 loopy/precompute.py      | 116 ++++++++++---------
 loopy/preprocess.py      |   6 +-
 loopy/subst.py           |  40 +++----
 loopy/symbolic.py        | 244 ++++++++++++++++++++++-----------------
 9 files changed, 307 insertions(+), 239 deletions(-)

diff --git a/MEMO b/MEMO
index d650e7ea4..f4e5c34e4 100644
--- a/MEMO
+++ b/MEMO
@@ -142,7 +142,7 @@ Dealt with
 - How can one automatically generate something like microblocks?
   -> Some sort of axis-adding transform?
 
-- ExpandingIdentityMapper
+- RuleAwareIdentityMapper
   extract_subst -> needs WalkMapper [actually fine as is]
   padding [DONE]
   replace make_unique_var_name [DONE]
diff --git a/loopy/__init__.py b/loopy/__init__.py
index 364540856..194a2c7b1 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -30,7 +30,8 @@ THE SOFTWARE.
 import islpy as isl
 from islpy import dim_type
 
-from loopy.symbolic import ExpandingIdentityMapper, ExpandingSubstitutionMapper
+from loopy.symbolic import (RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
+        SubstitutionRuleMappingContext)
 from loopy.diagnostic import LoopyError
 
 
@@ -126,11 +127,10 @@ __all__ = [
 
 # {{{ split inames
 
-class _InameSplitter(ExpandingIdentityMapper):
-    def __init__(self, kernel, within,
+class _InameSplitter(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, within,
             split_iname, outer_iname, inner_iname, replacement_index):
-        ExpandingIdentityMapper.__init__(self,
-                kernel.substitutions, kernel.get_var_name_generator())
+        super(_InameSplitter, self).__init__(rule_mapping_context)
 
         self.within = within
 
@@ -277,10 +277,13 @@ def split_iname(kernel, split_iname, inner_length,
     from loopy.context_matching import parse_stack_match
     within = parse_stack_match(within)
 
-    ins = _InameSplitter(kernel, within,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    ins = _InameSplitter(rule_mapping_context, within,
             split_iname, outer_iname, inner_iname, new_loop_index)
 
     kernel = ins.map_kernel(kernel)
+    kernel = rule_mapping_context.finish_kernel(kernel)
 
     if existing_tag is not None:
         kernel = tag_inames(kernel,
@@ -293,10 +296,10 @@ def split_iname(kernel, split_iname, inner_length,
 
 # {{{ join inames
 
-class _InameJoiner(ExpandingSubstitutionMapper):
-    def __init__(self, kernel, within, subst_func, joined_inames, new_iname):
-        ExpandingSubstitutionMapper.__init__(self,
-                kernel.substitutions, kernel.get_var_name_generator(),
+class _InameJoiner(RuleAwareSubstitutionMapper):
+    def __init__(self, rule_mapping_context, within, subst_func,
+            joined_inames, new_iname):
+        super(_InameJoiner, self).__init__(rule_mapping_context,
                 subst_func, within)
 
         self.joined_inames = set(joined_inames)
@@ -425,11 +428,14 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None):
     within = parse_stack_match(within)
 
     from pymbolic.mapper.substitutor import make_subst_func
-    ijoin = _InameJoiner(kernel, within,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    ijoin = _InameJoiner(rule_mapping_context, within,
             make_subst_func(subst_dict),
             inames, new_iname)
 
-    kernel = ijoin.map_kernel(kernel)
+    kernel = rule_mapping_context.finish_kernel(
+            ijoin.map_kernel(kernel))
 
     if tag is not None:
         kernel = tag_inames(kernel, {new_iname: tag})
@@ -494,11 +500,10 @@ def tag_inames(kernel, iname_to_tag, force=False):
 
 # {{{ duplicate inames
 
-class _InameDuplicator(ExpandingIdentityMapper):
-    def __init__(self, rules, make_unique_var_name,
+class _InameDuplicator(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context,
             old_to_new, within):
-        super(_InameDuplicator, self).__init__(
-                rules, make_unique_var_name)
+        super(_InameDuplicator, self).__init__(rule_mapping_context)
 
         self.old_to_new = old_to_new
         self.old_inames_set = set(six.iterkeys(old_to_new))
@@ -602,11 +607,14 @@ def duplicate_inames(knl, inames, within, new_inames=None, suffix=None,
 
     # {{{ change the inames in the code
 
-    indup = _InameDuplicator(knl.substitutions, name_gen,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            knl.substitutions, name_gen)
+    indup = _InameDuplicator(rule_mapping_context,
             old_to_new=dict(list(zip(inames, new_inames))),
             within=within)
 
-    knl = indup.map_kernel(knl)
+    knl = rule_mapping_context.finish_kernel(
+            indup.map_kernel(knl))
 
     # }}}
 
@@ -732,10 +740,13 @@ def link_inames(knl, inames, new_iname, within=None, tag=None):
     within = parse_stack_match(within)
 
     from pymbolic.mapper.substitutor import make_subst_func
-    ijoin = ExpandingSubstitutionMapper(knl.substitutions, var_name_gen,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            knl.substitutions, var_name_gen)
+    ijoin = RuleAwareSubstitutionMapper(knl.substitutions, var_name_gen,
                     make_subst_func(subst_dict), within)
 
-    knl = ijoin.map_kernel(knl)
+    knl = rule_mapping_context.finish_kernel(
+            ijoin.map_kernel(knl))
 
     # }}}
 
@@ -1187,10 +1198,10 @@ def tag_data_axes(knl, ary_names, dim_tags):
 
 # {{{ split_reduction
 
-class _ReductionSplitter(ExpandingIdentityMapper):
-    def __init__(self, kernel, within, inames, direction):
+class _ReductionSplitter(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, within, inames, direction):
         super(_ReductionSplitter, self).__init__(
-                kernel.substitutions, kernel.get_var_name_generator())
+                rule_mapping_context)
 
         self.within = within
         self.inames = inames
@@ -1230,8 +1241,12 @@ def _split_reduction(kernel, inames, direction, within=None):
     from loopy.context_matching import parse_stack_match
     within = parse_stack_match(within)
 
-    rsplit = _ReductionSplitter(kernel, within, inames, direction)
-    return rsplit.map_kernel(kernel)
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    rsplit = _ReductionSplitter(rule_mapping_context,
+            within, inames, direction)
+    return rule_mapping_context.finish_kernel(
+            rsplit.map_kernel(kernel))
 
 
 def split_reduction_inward(kernel, inames, within=None):
@@ -1328,11 +1343,13 @@ def _fix_parameter(kernel, name, value):
     from loopy.context_matching import parse_stack_match
     within = parse_stack_match(None)
 
-    from loopy.symbolic import ExpandingSubstitutionMapper
-    esubst_map = ExpandingSubstitutionMapper(
-            kernel.substitutions, kernel.get_var_name_generator(),
-            subst_func, within=within)
-    return (esubst_map.map_kernel(kernel)
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    esubst_map = RuleAwareSubstitutionMapper(
+            rule_mapping_context, subst_func, within=within)
+    return (
+            rule_mapping_context.finish_kernel(
+                esubst_map.map_kernel(kernel))
             .copy(
                 domains=new_domains,
                 args=new_args,
@@ -1633,10 +1650,14 @@ def affine_map_inames(kernel, old_inames, new_inames, equations):
     var_name_gen = kernel.get_var_name_generator()
 
     from pymbolic.mapper.substitutor import make_subst_func
-    old_to_new = ExpandingSubstitutionMapper(kernel.substitutions, var_name_gen,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, var_name_gen)
+    old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context,
             make_subst_func(subst_dict), within=lambda stack: True)
 
-    kernel = (old_to_new.map_kernel(kernel)
+    kernel = (
+            rule_mapping_context.finish_kernel(
+                old_to_new.map_kernel(kernel))
             .copy(
                 applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict]
                 ))
diff --git a/loopy/buffer.py b/loopy/buffer.py
index a5721890e..c64943f14 100644
--- a/loopy/buffer.py
+++ b/loopy/buffer.py
@@ -25,7 +25,8 @@ THE SOFTWARE.
 
 from loopy.array_buffer_map import (ArrayToBufferMap, NoOpArrayToBufferMap,
         AccessDescriptor)
-from loopy.symbolic import (get_dependencies, ExpandingIdentityMapper,
+from loopy.symbolic import (get_dependencies,
+        RuleAwareIdentityMapper, SubstitutionRuleMappingContext,
         SubstitutionMapper)
 from pymbolic.mapper.substitutor import make_subst_func
 
@@ -34,12 +35,11 @@ from pymbolic import var
 
 # {{{ replace array access
 
-class ArrayAccessReplacer(ExpandingIdentityMapper):
-    def __init__(self, kernel, var_name, within, array_base_map, buf_var):
-        super(ArrayAccessReplacer, self).__init__(
-                kernel.substitutions, kernel.get_var_name_generator())
+class ArrayAccessReplacer(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context,
+            var_name, within, array_base_map, buf_var):
+        super(ArrayAccessReplacer, self).__init__(rule_mapping_context)
 
-        self.kernel = kernel
         self.within = within
 
         self.array_base_map = array_base_map
@@ -320,8 +320,11 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
 
     # }}}
 
-    aar = ArrayAccessReplacer(kernel, var_name, within, abm, buf_var)
-    kernel = aar.map_kernel(kernel)
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    aar = ArrayAccessReplacer(rule_mapping_context, var_name,
+            within, abm, buf_var)
+    kernel = rule_mapping_context.finish_kernel(aar.map_kernel(kernel))
 
     did_write = False
     for insn_id in aar.modified_insn_ids:
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index a21d983d7..e99dedbf9 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -459,16 +459,16 @@ class ArgumentGuesser:
                 (assignee_var_name, _), = insn.assignees_and_indices()
                 self.all_written_names.add(assignee_var_name)
                 self.all_names.update(get_dependencies(
-                    self.submap(insn.assignee, insn.id, insn.tags)))
+                    self.submap(insn.assignee)))
                 self.all_names.update(get_dependencies(
-                    self.submap(insn.expression, insn.id, insn.tags)))
+                    self.submap(insn.expression)))
 
     def find_index_rank(self, name):
         irf = IndexRankFinder(name)
 
         for insn in self.instructions:
             insn.with_transformed_expressions(
-                    lambda expr: irf(self.submap(expr, insn.id, insn.tags)))
+                    lambda expr: irf(self.submap(expr)))
 
         if not irf.index_ranks:
             return 0
@@ -859,8 +859,7 @@ def guess_arg_shape_if_requested(kernel, default_order):
     from loopy.kernel.array import ArrayBase
     from loopy.symbolic import SubstitutionRuleExpander, AccessRangeMapper
 
-    submap = SubstitutionRuleExpander(kernel.substitutions,
-            kernel.get_var_name_generator())
+    submap = SubstitutionRuleExpander(kernel.substitutions)
 
     for arg in kernel.args:
         if isinstance(arg, ArrayBase) and arg.shape is lp.auto:
@@ -869,11 +868,12 @@ def guess_arg_shape_if_requested(kernel, default_order):
             try:
                 for insn in kernel.instructions:
                     if isinstance(insn, lp.ExpressionInstruction):
-                        armap(submap(insn.assignee, insn.id, insn.tags),
-                                kernel.insn_inames(insn))
-                        armap(submap(insn.expression, insn.id, insn.tags),
-                                kernel.insn_inames(insn))
+                        armap(submap(insn.assignee), kernel.insn_inames(insn))
+                        armap(submap(insn.expression), kernel.insn_inames(insn))
             except TypeError as e:
+                from traceback import print_exc
+                print_exc()
+
                 from loopy.diagnostic import LoopyError
                 raise LoopyError(
                         "Failed to (automatically, as requested) find "
diff --git a/loopy/padding.py b/loopy/padding.py
index 8f747d413..4951c6854 100644
--- a/loopy/padding.py
+++ b/loopy/padding.py
@@ -25,12 +25,12 @@ THE SOFTWARE.
 """
 
 
-from loopy.symbolic import ExpandingIdentityMapper
+from loopy.symbolic import RuleAwareIdentityMapper, SubstitutionRuleMappingContext
 
 
-class ArgAxisSplitHelper(ExpandingIdentityMapper):
-    def __init__(self, rules, var_name_gen, arg_names, handler):
-        ExpandingIdentityMapper.__init__(self, rules, var_name_gen)
+class ArgAxisSplitHelper(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, arg_names, handler):
+        super(ArgAxisSplitHelper, self).__init__(rule_mapping_context)
         self.arg_names = arg_names
         self.handler = handler
 
@@ -38,7 +38,7 @@ class ArgAxisSplitHelper(ExpandingIdentityMapper):
         if expr.aggregate.name in self.arg_names:
             return self.handler(expr)
         else:
-            return ExpandingIdentityMapper.map_subscript(self, expr, expn_state)
+            return super(ArgAxisSplitHelper, self).map_subscript(expr, expn_state)
 
 
 def split_arg_axis(kernel, args_and_axes, count, auto_split_inames=True):
@@ -205,9 +205,11 @@ def split_arg_axis(kernel, args_and_axes, count, auto_split_inames=True):
 
         return expr.aggregate.index(tuple(idx))
 
-    aash = ArgAxisSplitHelper(kernel.substitutions, var_name_gen,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, var_name_gen)
+    aash = ArgAxisSplitHelper(rule_mapping_context,
             set(six.iterkeys(arg_to_rest)), split_access_axis)
-    kernel = aash.map_kernel(kernel)
+    kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))
 
     kernel = kernel.copy(args=new_args)
 
diff --git a/loopy/precompute.py b/loopy/precompute.py
index d37b53d7b..e516dde13 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -25,8 +25,9 @@ THE SOFTWARE.
 """
 
 
-from loopy.symbolic import (get_dependencies, ExpandingSubstitutionMapper,
-        ExpandingIdentityMapper)
+from loopy.symbolic import (get_dependencies,
+        RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
+        SubstitutionRuleMappingContext)
 from pymbolic.mapper.substitutor import make_subst_func
 import numpy as np
 
@@ -58,10 +59,9 @@ def storage_axis_exprs(storage_axis_sources, args):
 
 # {{{ gather rule invocations
 
-class RuleInvocationGatherer(ExpandingIdentityMapper):
-    def __init__(self, kernel, subst_name, subst_tag, within):
-        ExpandingIdentityMapper.__init__(self,
-                kernel.substitutions, kernel.get_var_name_generator())
+class RuleInvocationGatherer(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
+        super(RuleInvocationGatherer, self).__init__(rule_mapping_context)
 
         from loopy.symbolic import SubstitutionRuleExpander
         self.subst_expander = SubstitutionRuleExpander(
@@ -83,18 +83,17 @@ class RuleInvocationGatherer(ExpandingIdentityMapper):
         process_me = process_me and self.within(expn_state.stack)
 
         if not process_me:
-            return ExpandingIdentityMapper.map_substitution(
-                    self, name, tag, arguments, expn_state)
+            return super(RuleInvocationGatherer, self).map_substitution(
+                    name, tag, arguments, expn_state)
 
-        rule = self.old_subst_rules[name]
+        rule = self.rule_mapping_context.old_subst_rules[name]
         arg_context = self.make_new_arg_context(
                     name, rule.arguments, arguments, expn_state.arg_context)
 
         arg_deps = set()
         for arg_val in six.itervalues(arg_context):
             arg_deps = (arg_deps
-                    | get_dependencies(self.subst_expander(
-                        arg_val, insn_id=None, insn_tags=None)))
+                    | get_dependencies(self.subst_expander(arg_val)))
 
         # FIXME: This is too strict--and the footprint machinery
         # needs to be taught how to deal with locally constant
@@ -109,8 +108,8 @@ class RuleInvocationGatherer(ExpandingIdentityMapper):
                         ", ".join(arg_deps - self.kernel.all_inames()),
                         ))
 
-            return ExpandingIdentityMapper.map_substitution(
-                    self, name, tag, arguments, expn_state)
+            return super(RuleInvocationGatherer, self).map_substitution(
+                    name, tag, arguments, expn_state)
 
         args = [arg_context[arg_name] for arg_name in rule.arguments]
 
@@ -127,16 +126,14 @@ class RuleInvocationGatherer(ExpandingIdentityMapper):
 
 # {{{ replace rule invocation
 
-class RuleInvocationReplacer(ExpandingIdentityMapper):
-    def __init__(self, kernel, subst_name, subst_tag, within,
+class RuleInvocationReplacer(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, subst_name, subst_tag, within,
             access_descriptors, array_base_map,
             storage_axis_names, storage_axis_sources,
             non1_storage_axis_names,
             target_var_name):
-        ExpandingIdentityMapper.__init__(self,
-                kernel.substitutions, kernel.get_var_name_generator())
+        super(RuleInvocationReplacer, self).__init__(rule_mapping_context)
 
-        self.kernel = kernel
         self.subst_name = subst_name
         self.subst_tag = subst_tag
         self.within = within
@@ -155,12 +152,12 @@ class RuleInvocationReplacer(ExpandingIdentityMapper):
                 name == self.subst_name
                 and self.within(expn_state.stack)
                 and (self.subst_tag is None or self.subst_tag == tag)):
-            return ExpandingIdentityMapper.map_substitution(
-                    self, name, tag, arguments, expn_state)
+            return super(RuleInvocationReplacer, self).map_substitution(
+                    name, tag, arguments, expn_state)
 
         # {{{ check if in footprint
 
-        rule = self.old_subst_rules[name]
+        rule = self.rule_mapping_context.old_subst_rules[name]
         arg_context = self.make_new_arg_context(
                     name, rule.arguments, arguments, expn_state.arg_context)
         args = [arg_context[arg_name] for arg_name in rule.arguments]
@@ -170,8 +167,8 @@ class RuleInvocationReplacer(ExpandingIdentityMapper):
                     self.storage_axis_sources, args))
 
         if not self.array_base_map.is_access_descriptor_in_footprint(accdesc):
-            return ExpandingIdentityMapper.map_substitution(
-                    self, name, tag, arguments, expn_state)
+            return super(RuleInvocationReplacer, self).map_substitution(
+                    name, tag, arguments, expn_state)
 
         # }}}
 
@@ -202,10 +199,9 @@ class RuleInvocationReplacer(ExpandingIdentityMapper):
         if stor_subscript:
             new_outer_expr = new_outer_expr.index(tuple(stor_subscript))
 
-        # Can't possibly be nested, but recurse anyway to
-        # make sure substitution rules referenced below here
-        # do not get thrown away.
-        self.rec(rule.expression, expn_state.copy(arg_context={}))
+        # Can't possibly be nested, and no need to traverse
+        # further as compute expression has already been seen
+        # by rule_mapping_context.
 
         return new_outer_expr
 
@@ -369,7 +365,11 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     # {{{ gather up invocations in kernel code, finish access_descriptors
 
     if not footprint_generators:
-        invg = RuleInvocationGatherer(kernel, subst_name, subst_tag, within)
+        rule_mapping_context = SubstitutionRuleMappingContext(
+                kernel.substitutions, kernel.get_var_name_generator())
+        invg = RuleInvocationGatherer(
+                rule_mapping_context, kernel, subst_name, subst_tag, within)
+        del rule_mapping_context
 
         import loopy as lp
         for insn in kernel.instructions:
@@ -405,7 +405,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     submap = SubstitutionRuleExpander(kernel.substitutions)
 
     value_inames = get_dependencies(
-            submap(subst.expression, insn_id=None, insn_tags=None)
+            submap(subst.expression)
             ) & kernel.all_inames()
     if value_inames - expanding_usage_arg_deps < extra_storage_axes:
         raise RuntimeError("unreferenced sweep inames specified: "
@@ -513,6 +513,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
         non1_storage_axis_names = []
         abm = NoOpArrayToBufferMap()
 
+    kernel = kernel.copy(domains=new_kernel_domains)
+
     # {{{ set up compute insn
 
     target_var_name = var_name_gen(based_on=c_subst_name)
@@ -522,24 +524,54 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
         assignee = assignee.index(
                 tuple(var(iname) for iname in non1_storage_axis_names))
 
+    # {{{ process substitutions on compute instruction
+
+    storage_axis_subst_dict = {}
+
+    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
+        if arg_name in non1_storage_axis_names:
+            arg = var(arg_name)
+        else:
+            arg = 0
+
+        storage_axis_subst_dict[
+                prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi
+
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+
+    from loopy.context_matching import AllStackMatch
+    expr_subst_map = RuleAwareSubstitutionMapper(
+            rule_mapping_context,
+            make_subst_func(storage_axis_subst_dict),
+            within=AllStackMatch())
+
+    compute_expression = expr_subst_map(subst.expression, None, None)
+
+    # }}}
+
     from loopy.kernel.data import ExpressionInstruction
     compute_insn_id = kernel.make_unique_instruction_id(based_on=c_subst_name)
     compute_insn = ExpressionInstruction(
             id=compute_insn_id,
             assignee=assignee,
-            expression=subst.expression)
+            expression=compute_expression)
 
     # }}}
 
     # {{{ substitute rule into expressions in kernel (if within footprint)
 
-    invr = RuleInvocationReplacer(kernel, subst_name, subst_tag, within,
+    invr = RuleInvocationReplacer(rule_mapping_context,
+            subst_name, subst_tag, within,
             access_descriptors, abm,
             storage_axis_names, storage_axis_sources,
             non1_storage_axis_names,
             target_var_name)
 
     kernel = invr.map_kernel(kernel)
+    kernel = kernel.copy(
+            instructions=[compute_insn] + kernel.instructions)
+    kernel = rule_mapping_context.finish_kernel(kernel)
 
     # }}}
 
@@ -566,31 +598,9 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     new_temporary_variables[target_var_name] = temp_var
 
-    # }}}
-
     kernel = kernel.copy(
-            domains=new_kernel_domains,
-            instructions=[compute_insn] + kernel.instructions,
             temporary_variables=new_temporary_variables)
 
-    # {{{ process substitutions on compute instruction
-
-    storage_axis_subst_dict = {}
-
-    for arg_name, bi in zip(storage_axis_names, abm.storage_base_indices):
-        if arg_name in non1_storage_axis_names:
-            arg = var(arg_name)
-        else:
-            arg = 0
-
-        storage_axis_subst_dict[prior_storage_axis_name_dict.get(arg_name, arg_name)] = arg+bi
-
-    expr_subst_map = ExpandingSubstitutionMapper(
-            kernel.substitutions, kernel.get_var_name_generator(),
-            make_subst_func(storage_axis_subst_dict),
-            parse_stack_match("... < "+compute_insn_id))
-    kernel = expr_subst_map.map_kernel(kernel)
-
     # }}}
 
     from loopy import tag_inames
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 145272907..082b18895 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -88,8 +88,7 @@ def _infer_var_type(kernel, var_name, type_inf_mapper, subst_expander):
         if not isinstance(writer_insn, lp.ExpressionInstruction):
             continue
 
-        expr = subst_expander(writer_insn.expression,
-                insn_id=writer_insn_id, insn_tags=writer_insn.tags)
+        expr = subst_expander(writer_insn.expression)
 
         try:
             debug("             via expr %s" % expr)
@@ -171,8 +170,7 @@ def infer_unknown_types(kernel, expect_completion=False):
                 ]))
 
     from loopy.symbolic import SubstitutionRuleExpander
-    subst_expander = SubstitutionRuleExpander(kernel.substitutions,
-            kernel.get_var_name_generator())
+    subst_expander = SubstitutionRuleExpander(kernel.substitutions)
 
     # {{{ work on type inference queue
 
diff --git a/loopy/subst.py b/loopy/subst.py
index d1a643f30..412623c90 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -27,7 +27,7 @@ THE SOFTWARE.
 
 from loopy.symbolic import (
         get_dependencies, SubstitutionMapper,
-        ExpandingIdentityMapper)
+        RuleAwareIdentityMapper, SubstitutionRuleMappingContext)
 from loopy.diagnostic import LoopyError
 from pymbolic.mapper.substitutor import make_subst_func
 
@@ -200,15 +200,13 @@ def extract_subst(kernel, subst_name, template, parameters=()):
 
 # {{{ temporary_to_subst
 
-class TemporaryToSubstChanger(ExpandingIdentityMapper):
-    def __init__(self, kernel, temp_name, definition_insn_ids,
+class TemporaryToSubstChanger(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, temp_name, definition_insn_ids,
             usage_to_definition, within):
-        self.var_name_gen = kernel.get_var_name_generator()
+        self.var_name_gen = rule_mapping_context.make_unique_var_name
 
-        super(TemporaryToSubstChanger, self).__init__(
-                kernel.substitutions, self.var_name_gen)
+        super(TemporaryToSubstChanger, self).__init__(rule_mapping_context)
 
-        self.kernel = kernel
         self.temp_name = temp_name
         self.definition_insn_ids = definition_insn_ids
         self.usage_to_definition = usage_to_definition
@@ -349,10 +347,13 @@ def temporary_to_subst(kernel, temp_name, within=None):
     from loopy.context_matching import parse_stack_match
     within = parse_stack_match(within)
 
-    tts = TemporaryToSubstChanger(kernel, temp_name, definition_insn_ids,
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    tts = TemporaryToSubstChanger(rule_mapping_context,
+            temp_name, definition_insn_ids,
             usage_to_definition, within)
 
-    kernel = tts.map_kernel(kernel)
+    kernel = rule_mapping_context.finish_kernel(tts.map_kernel(kernel))
 
     from loopy.kernel.data import SubstitutionRule
 
@@ -412,19 +413,18 @@ def temporary_to_subst(kernel, temp_name, within=None):
 # }}}
 
 
-def expand_subst(kernel, ctx_match=None):
+def expand_subst(kernel, within=None):
     logger.debug("%s: expand subst" % kernel.name)
 
-    from loopy.symbolic import SubstitutionRuleExpander
+    from loopy.symbolic import RuleAwareSubstitutionRuleExpander
     from loopy.context_matching import parse_stack_match
-    submap = SubstitutionRuleExpander(kernel.substitutions,
-            kernel.get_var_name_generator(),
-            parse_stack_match(ctx_match))
-
-    kernel = submap.map_kernel(kernel)
-    if ctx_match is None:
-        return kernel.copy(substitutions={})
-    else:
-        return kernel
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    submap = RuleAwareSubstitutionRuleExpander(
+            rule_mapping_context,
+            kernel.substitutions,
+            parse_stack_match(within))
+
+    return rule_mapping_context.finish_kernel(submap.map_kernel(kernel))
 
 # vim: foldmethod=marker
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index a17c4cdd5..8897f3593 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -1,11 +1,8 @@
 """Pymbolic mappers for loopy."""
 
-from __future__ import division
-from __future__ import absolute_import
+from __future__ import division, absolute_import
 import six
-from six.moves import range
-from six.moves import zip
-from functools import reduce
+from six.moves import range, zip, reduce
 
 __copyright__ = "Copyright (C) 2012 Andreas Kloeckner"
 
@@ -197,6 +194,41 @@ class DependencyMapper(DependencyMapperBase):
 
     map_linear_subscript = DependencyMapperBase.map_subscript
 
+
+class SubstitutionRuleExpander(IdentityMapper):
+    def __init__(self, rules):
+        self.rules = rules
+
+    def map_variable(self, expr):
+        if expr.name in self.rules:
+            return self.map_substitution(expr.name, self.rules[expr.name], ())
+        else:
+            return super(SubstitutionRuleExpander, self).map_variable(expr)
+
+    def map_call(self, expr):
+        if expr.function.name in self.rules:
+            return self.map_substitution(
+                    expr.function.name,
+                    self.rules[expr.function.name],
+                    expr.parameters)
+        else:
+            return super(SubstitutionRuleExpander, self).map_call(expr)
+
+    def map_substitution(self, name, rule, arguments):
+        if len(rule.arguments) != len(arguments):
+            from loopy.diagnostic import LoopyError
+            raise LoopyError("number of arguments to '%s' does not match "
+                    "definition" % name)
+
+        from pymbolic.mapper.substitutor import make_subst_func
+        submap = SubstitutionMapper(
+                make_subst_func(
+                    dict(zip(rule.arguments, arguments))))
+
+        expr = submap(rule.expression)
+
+        return self.rec(expr)
+
 # }}}
 
 
@@ -333,7 +365,7 @@ def get_dependencies(expr):
     return frozenset(dep.name for dep in dep_mapper(expr))
 
 
-# {{{ identity mapper that expands subst rules on the fly
+# {{{ rule-aware mappers
 
 def parse_tagged_name(expr):
     if isinstance(expr, TaggedVariable):
@@ -409,14 +441,7 @@ def rename_subst_rules_in_instructions(insns, renames):
             for insn in insns]
 
 
-class ExpandingIdentityMapper(IdentityMapper):
-    """Note: the third argument dragged around by this mapper is the
-    current :class:`ExpansionState`.
-
-    Subclasses of this must be careful to not touch identifiers that
-    are in :attr:`ExpansionState.arg_context`.
-    """
-
+class SubstitutionRuleMappingContext(object):
     def __init__(self, old_subst_rules, make_unique_var_name):
         self.old_subst_rules = old_subst_rules
         self.make_unique_var_name = make_unique_var_name
@@ -445,9 +470,84 @@ class ExpandingIdentityMapper(IdentityMapper):
         self.subst_rule_use_count[key] = self.subst_rule_use_count.get(key, 0) + 1
         return new_name
 
+    def _get_new_substitutions_and_renames(self):
+        """This makes a new dictionary of substitutions from the ones
+        encountered in mapping all the encountered expressions.
+        It tries hard to keep substitution names the same--i.e.
+        if all derivative versions of a substitution rule ended
+        up with the same mapped version, then this version should
+        retain the name that the substitution rule had previously.
+        Unfortunately, this can't be done in a single pass, and so
+        the routine returns an additional dictionary *subst_renames*
+        of renamings to be performed on the processed expressions.
+
+        The returned substitutions already have the rename applied
+        to them.
+
+        :returns: (new_substitutions, subst_renames)
+        """
+
+        from loopy.kernel.data import SubstitutionRule
+
+        orig_name_histogram = {}
+        for key, (name, orig_name) in six.iteritems(self.subst_rule_registry):
+            if self.subst_rule_use_count.get(key, 0):
+                orig_name_histogram[orig_name] = \
+                        orig_name_histogram.get(orig_name, 0) + 1
+
+        result = {}
+        renames = {}
+
+        for key, (name, orig_name) in six.iteritems(self.subst_rule_registry):
+            args, body = key
+
+            if self.subst_rule_use_count.get(key, 0):
+                if orig_name_histogram[orig_name] == 1 and name != orig_name:
+                    renames[name] = orig_name
+                    name = orig_name
+
+                result[name] = SubstitutionRule(
+                        name=name,
+                        arguments=args,
+                        expression=body)
+
+        # {{{ perform renames on new substitutions
+
+        subst_renamer = SubstitutionRuleRenamer(renames)
+
+        renamed_result = {}
+        for name, rule in six.iteritems(result):
+            renamed_result[name] = rule.copy(
+                    expression=subst_renamer(rule.expression))
+
+        # }}}
+
+        return renamed_result, renames
+
+    def finish_kernel(self, kernel):
+        new_substs, renames = self._get_new_substitutions_and_renames()
+
+        new_insns = rename_subst_rules_in_instructions(kernel.instructions, renames)
+
+        return kernel.copy(
+            substitutions=new_substs,
+            instructions=new_insns)
+
+
+class RuleAwareIdentityMapper(IdentityMapper):
+    """Note: the third argument dragged around by this mapper is the
+    current :class:`ExpansionState`.
+
+    Subclasses of this must be careful to not touch identifiers that
+    are in :attr:`ExpansionState.arg_context`.
+    """
+
+    def __init__(self, rule_mapping_context):
+        self.rule_mapping_context = rule_mapping_context
+
     def map_variable(self, expr, expn_state):
         name, tag = parse_tagged_name(expr)
-        if name not in self.old_subst_rules:
+        if name not in self.rule_mapping_context.old_subst_rules:
             return IdentityMapper.map_variable(self, expr, expn_state)
         else:
             return self.map_substitution(name, tag, (), expn_state)
@@ -458,8 +558,8 @@ class ExpandingIdentityMapper(IdentityMapper):
 
         name, tag = parse_tagged_name(expr.function)
 
-        if name not in self.old_subst_rules:
-            return IdentityMapper.map_call(self, expr, expn_state)
+        if name not in self.rule_mapping_context.old_subst_rules:
+            return super(RuleAwareIdentityMapper, self).map_call(expr, expn_state)
         else:
             return self.map_substitution(name, tag, expr.parameters, expn_state)
 
@@ -476,7 +576,7 @@ class ExpandingIdentityMapper(IdentityMapper):
                 for formal_arg_name, arg_value in zip(arg_names, arguments))
 
     def map_substitution(self, name, tag, arguments, expn_state):
-        rule = self.old_subst_rules[name]
+        rule = self.rule_mapping_context.old_subst_rules[name]
 
         rec_arguments = self.rec(arguments, expn_state)
 
@@ -492,7 +592,8 @@ class ExpandingIdentityMapper(IdentityMapper):
 
         result = self.rec(rule.expression, new_expn_state)
 
-        new_name = self.register_subst_rule(name, rule.arguments, result)
+        new_name = self.rule_mapping_context.register_subst_rule(
+                name, rule.arguments, result)
 
         if tag is None:
             sym = Variable(new_name)
@@ -513,60 +614,6 @@ class ExpandingIdentityMapper(IdentityMapper):
         return IdentityMapper.__call__(self, expr, ExpansionState(
             stack=stack, arg_context={}))
 
-    def _get_new_substitutions_and_renames(self):
-        """This makes a new dictionary of substitutions from the ones
-        encountered in mapping all the encountered expressions.
-        It tries hard to keep substitution names the same--i.e.
-        if all derivative versions of a substitution rule ended
-        up with the same mapped version, then this version should
-        retain the name that the substitution rule had previously.
-        Unfortunately, this can't be done in a single pass, and so
-        the routine returns an additional dictionary *subst_renames*
-        of renamings to be performed on the processed expressions.
-
-        The returned substitutions already have the rename applied
-        to them.
-
-        :returns: (new_substitutions, subst_renames)
-        """
-
-        from loopy.kernel.data import SubstitutionRule
-
-        orig_name_histogram = {}
-        for key, (name, orig_name) in six.iteritems(self.subst_rule_registry):
-            if self.subst_rule_use_count.get(key, 0):
-                orig_name_histogram[orig_name] = \
-                        orig_name_histogram.get(orig_name, 0) + 1
-
-        result = {}
-        renames = {}
-
-        for key, (name, orig_name) in six.iteritems(self.subst_rule_registry):
-            args, body = key
-
-            if self.subst_rule_use_count.get(key, 0):
-                if orig_name_histogram[orig_name] == 1 and name != orig_name:
-                    renames[name] = orig_name
-                    name = orig_name
-
-                result[name] = SubstitutionRule(
-                        name=name,
-                        arguments=args,
-                        expression=body)
-
-        # {{{ perform renames on new substitutions
-
-        subst_renamer = SubstitutionRuleRenamer(renames)
-
-        renamed_result = {}
-        for name, rule in six.iteritems(result):
-            renamed_result[name] = rule.copy(
-                    expression=subst_renamer(rule.expression))
-
-        # }}}
-
-        return renamed_result, renames
-
     def map_instruction(self, insn):
         return insn
 
@@ -575,24 +622,16 @@ class ExpandingIdentityMapper(IdentityMapper):
                 # While subst rules are not allowed in assignees, the mapper
                 # may perform tasks entirely unrelated to subst rules, so
                 # we must map assignees, too.
-
-                insn.with_transformed_expressions(self, insn.id, insn.tags)
+                self.map_instruction(
+                    insn.with_transformed_expressions(self, insn.id, insn.tags))
                 for insn in kernel.instructions]
 
-        new_substs, renames = self._get_new_substitutions_and_renames()
+        return kernel.copy(instructions=new_insns)
 
-        new_insns = [self.map_instruction(insn)
-                for insn in rename_subst_rules_in_instructions(
-                    new_insns, renames)]
-
-        return kernel.copy(
-            substitutions=new_substs,
-            instructions=new_insns)
 
-
-class ExpandingSubstitutionMapper(ExpandingIdentityMapper):
-    def __init__(self, rules, make_unique_var_name, subst_func, within):
-        ExpandingIdentityMapper.__init__(self, rules, make_unique_var_name)
+class RuleAwareSubstitutionMapper(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, subst_func, within):
+        super(RuleAwareSubstitutionMapper, self).__init__(rule_mapping_context)
 
         self.subst_func = subst_func
         self.within = within
@@ -600,28 +639,23 @@ class ExpandingSubstitutionMapper(ExpandingIdentityMapper):
     def map_variable(self, expr, expn_state):
         if (expr.name in expn_state.arg_context
                 or not self.within(expn_state.stack)):
-            return ExpandingIdentityMapper.map_variable(self, expr, expn_state)
+            return super(RuleAwareSubstitutionMapper, self).map_variable(
+                    expr, expn_state)
 
         result = self.subst_func(expr)
         if result is not None:
             return result
         else:
-            return ExpandingIdentityMapper.map_variable(self, expr, expn_state)
+            return super(RuleAwareSubstitutionMapper, self).map_variable(
+                    expr, expn_state)
 
-# }}}
 
+class RuleAwareSubstitutionRuleExpander(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, rules, within):
+        super(RuleAwareSubstitutionRuleExpander, self).__init__(rule_mapping_context)
 
-# {{{ substitution rule expander
-
-class SubstitutionRuleExpander(ExpandingIdentityMapper):
-    def __init__(self, rules, make_unique_var=None, ctx_match=None):
-        ExpandingIdentityMapper.__init__(self, rules, make_unique_var)
-
-        if ctx_match is None:
-            from loopy.context_matching import AllStackMatch
-            ctx_match = AllStackMatch()
-
-        self.ctx_match = ctx_match
+        self.rules = rules
+        self.within = within
 
     def map_substitution(self, name, tag, arguments, expn_state):
         if tag is None:
@@ -631,9 +665,9 @@ class SubstitutionRuleExpander(ExpandingIdentityMapper):
 
         new_stack = expn_state.stack + ((name, tags),)
 
-        if self.ctx_match(new_stack):
+        if self.within(new_stack):
             # expand
-            rule = self.old_subst_rules[name]
+            rule = self.rules[name]
 
             new_expn_state = expn_state.copy(
                     stack=new_stack,
@@ -651,8 +685,8 @@ class SubstitutionRuleExpander(ExpandingIdentityMapper):
 
         else:
             # do not expand
-            return ExpandingIdentityMapper.map_substitution(
-                    self, name, tag, arguments, expn_state)
+            return super(RuleAwareSubstitutionRuleExpander, self).map_substitution(
+                    name, tag, arguments, expn_state)
 
 # }}}
 
-- 
GitLab