From 90597d80b55f35f48492fece69b25ee19036733f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 22 Jun 2015 22:20:42 -0500
Subject: [PATCH] Implement kernel query language

---
 doc/reference.rst         |   2 +-
 loopy/__init__.py         |  63 ++++---
 loopy/buffer.py           |  12 +-
 loopy/context_matching.py | 388 ++++++++++++++++++++++++++++----------
 loopy/fusion.py           |   4 +-
 loopy/precompute.py       |  20 +-
 loopy/subst.py            |   5 +-
 loopy/symbolic.py         |  35 ++--
 test/test_fortran.py      |   2 +-
 test/test_loopy.py        |   2 +-
 test/test_sem_reagan.py   |  10 +-
 11 files changed, 384 insertions(+), 159 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index 6ed83d430..c7435bbf7 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -356,7 +356,7 @@ TODO: Matching instruction tags
 
 .. automodule:: loopy.context_matching
 
-.. autofunction:: parse_id_match
+.. autofunction:: parse_match
 
 .. autofunction:: parse_stack_match
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index dcd8e27c1..58029467c 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -152,7 +152,10 @@ class _InameSplitter(RuleAwareIdentityMapper):
     def map_reduction(self, expr, expn_state):
         if (self.split_iname in expr.inames
                 and self.split_iname not in expn_state.arg_context
-                and self.within(expn_state.stack)):
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             new_inames = list(expr.inames)
             new_inames.remove(self.split_iname)
             new_inames.extend([self.outer_iname, self.inner_iname])
@@ -166,7 +169,10 @@ class _InameSplitter(RuleAwareIdentityMapper):
     def map_variable(self, expr, expn_state):
         if (expr.name == self.split_iname
                 and self.split_iname not in expn_state.arg_context
-                and self.within(expn_state.stack)):
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             return self.replacement_index
         else:
             return super(_InameSplitter, self).map_variable(expr, expn_state)
@@ -318,7 +324,10 @@ class _InameJoiner(RuleAwareSubstitutionMapper):
         expr_inames = set(expr.inames)
         overlap = (self.join_inames & expr_inames
                 - set(expn_state.arg_context))
-        if overlap and self.within(expn_state.stack):
+        if overlap and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             if overlap != expr_inames:
                 raise LoopyError(
                         "Cannot join inames '%s' if there is a reduction "
@@ -520,7 +529,10 @@ class _InameDuplicator(RuleAwareIdentityMapper):
 
     def map_reduction(self, expr, expn_state):
         if (set(expr.inames) & self.old_inames_set
-                and self.within(expn_state.stack)):
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             new_inames = tuple(
                     self.old_to_new.get(iname, iname)
                     if iname not in expn_state.arg_context
@@ -538,14 +550,17 @@ class _InameDuplicator(RuleAwareIdentityMapper):
 
         if (new_name is None
                 or expr.name in expn_state.arg_context
-                or not self.within(expn_state.stack)):
+                or not self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             return super(_InameDuplicator, self).map_variable(expr, expn_state)
         else:
             from pymbolic import var
             return var(new_name)
 
-    def map_instruction(self, insn):
-        if not self.within(((insn.id, insn.tags),)):
+    def map_instruction(self, kernel, insn):
+        if not self.within(kernel, insn, ()):
             return insn
 
         new_fid = frozenset(
@@ -1050,7 +1065,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
     # If the rule survived past precompute() (i.e. some accesses fell outside
     # the footprint), get rid of it before moving on.
     if rule_name in new_kernel.substitutions:
-        return expand_subst(new_kernel, rule_name)
+        return expand_subst(new_kernel, "id:"+rule_name)
     else:
         return new_kernel
 
@@ -1060,19 +1075,19 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 # {{{ instruction processing
 
 def find_instructions(kernel, insn_match):
-    from loopy.context_matching import parse_id_match
-    match = parse_id_match(insn_match)
-    return [insn for insn in kernel.instructions if match(insn.id, insn.tags)]
+    from loopy.context_matching import parse_match
+    match = parse_match(insn_match)
+    return [insn for insn in kernel.instructions if match(kernel, insn)]
 
 
 def map_instructions(kernel, insn_match, f):
-    from loopy.context_matching import parse_id_match
-    match = parse_id_match(insn_match)
+    from loopy.context_matching import parse_match
+    match = parse_match(insn_match)
 
     new_insns = []
 
     for insn in kernel.instructions:
-        if match(insn.id, None):
+        if match(kernel, insn):
             new_insns.append(f(insn))
         else:
             new_insns.append(insn)
@@ -1084,7 +1099,7 @@ def set_instruction_priority(kernel, insn_match, priority):
     """Set the priority of instructions matching *insn_match* to *priority*.
 
     *insn_match* may be any instruction id match understood by
-    :func:`loopy.context_matching.parse_id_match`.
+    :func:`loopy.context_matching.parse_match`.
     """
 
     def set_prio(insn):
@@ -1098,7 +1113,7 @@ def add_dependency(kernel, insn_match, dependency):
     by *insn_match*.
 
     *insn_match* may be any instruction id match understood by
-    :func:`loopy.context_matching.parse_id_match`.
+    :func:`loopy.context_matching.parse_match`.
     """
 
     def add_dep(insn):
@@ -1220,7 +1235,11 @@ class _ReductionSplitter(RuleAwareIdentityMapper):
             # FIXME
             raise NotImplementedError()
 
-        if self.inames <= set(expr.inames) and self.within(expn_state.stack):
+        if (self.inames <= set(expr.inames)
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             leftover_inames = set(expr.inames) - self.inames
 
             from loopy.symbolic import Reduction
@@ -1659,10 +1678,12 @@ def affine_map_inames(kernel, old_inames, new_inames, equations):
     var_name_gen = kernel.get_var_name_generator()
 
     from pymbolic.mapper.substitutor import make_subst_func
+    from loopy.context_matching import parse_stack_match
+
     rule_mapping_context = SubstitutionRuleMappingContext(
             kernel.substitutions, var_name_gen)
     old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context,
-            make_subst_func(subst_dict), within=lambda stack: True)
+            make_subst_func(subst_dict), within=parse_stack_match(None))
 
     kernel = (
             rule_mapping_context.finish_kernel(
@@ -1792,12 +1813,12 @@ def fold_constants(kernel):
 # {{{ tag_instructions
 
 def tag_instructions(kernel, new_tag, within=None):
-    from loopy.context_matching import parse_stack_match
-    within = parse_stack_match(within)
+    from loopy.context_matching import parse_match
+    within = parse_match(within)
 
     new_insns = []
     for insn in kernel.instructions:
-        if within(((insn.id, insn.tags),)):
+        if within(kernel, insn):
             new_insns.append(
                     insn.copy(tags=insn.tags + (new_tag,)))
         else:
diff --git a/loopy/buffer.py b/loopy/buffer.py
index d155dba7e..fdc3774b2 100644
--- a/loopy/buffer.py
+++ b/loopy/buffer.py
@@ -51,7 +51,10 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
 
     def map_variable(self, expr, expn_state):
         result = None
-        if expr.name == self.var_name and self.within(expn_state):
+        if expr.name == self.var_name and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             result = self.map_array_access((), expn_state)
 
         if result is None:
@@ -62,7 +65,10 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
 
     def map_subscript(self, expr, expn_state):
         result = None
-        if expr.aggregate.name == self.var_name and self.within(expn_state):
+        if expr.aggregate.name == self.var_name and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             result = self.map_array_access(expr.index, expn_state)
 
         if result is None:
@@ -172,7 +178,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
 
     access_descriptors = []
     for insn in kernel.instructions:
-        if not within((insn.id, insn.tags)):
+        if not within(kernel, insn, ()):
             continue
 
         for assignee, index in insn.assignees_and_indices():
diff --git a/loopy/context_matching.py b/loopy/context_matching.py
index b259a0ddd..cb0d6e956 100644
--- a/loopy/context_matching.py
+++ b/loopy/context_matching.py
@@ -29,148 +29,323 @@ THE SOFTWARE.
 
 NoneType = type(None)
 
+from pytools.lex import RE
 
-# {{{ id match objects
 
-class AllMatch(object):
-    def __call__(self, identifier, tag):
+def re_from_glob(s):
+    import re
+    from fnmatch import translate
+    return re.compile("^"+translate(s.strip())+"$")
+
+# {{{ parsing
+
+# {{{ lexer data
+
+_and = intern("and")
+_or = intern("or")
+_not = intern("not")
+_openpar = intern("openpar")
+_closepar = intern("closepar")
+
+_id = intern("_id")
+_tag = intern("_tag")
+_writes = intern("_writes")
+_reads = intern("_reads")
+_iname = intern("_iname")
+
+_whitespace = intern("_whitespace")
+
+# }}}
+
+
+_LEX_TABLE = [
+    (_and, RE(r"and\b")),
+    (_or, RE(r"or\b")),
+    (_not, RE(r"not\b")),
+    (_openpar, RE(r"\(")),
+    (_closepar, RE(r"\)")),
+
+    # TERMINALS
+    (_id, RE(r"id:([\w?*]+)")),
+    (_tag, RE(r"tag:([\w?*]+)")),
+    (_writes, RE(r"writes:([\w?*]+)")),
+    (_reads, RE(r"reads:([\w?*]+)")),
+    (_iname, RE(r"iname:([\w?*]+)")),
+
+    (_whitespace, RE("[ \t]+")),
+    ]
+
+
+_TERMINALS = ([_id, _tag, _writes, _reads, _iname])
+
+# {{{ operator precedence
+
+_PREC_OR = 10
+_PREC_AND = 20
+_PREC_NOT = 30
+
+# }}}
+
+
+# {{{ match expression
+
+class MatchExpressionBase(object):
+    def __call__(self, kernel, matchable):
+        raise NotImplementedError
+
+
+class AllMatchExpression(MatchExpressionBase):
+    def __call__(self, kernel, matchable):
         return True
 
 
-class RegexIdentifierMatch(object):
-    def __init__(self, id_re, tag_re=None):
-        self.id_re = id_re
-        self.tag_re = tag_re
+class AndMatchExpression(MatchExpressionBase):
+    def __init__(self, children):
+        self.children = children
+
+    def __call__(self, kernel, matchable):
+        return all(ch(kernel, matchable) for ch in self.children)
+
+    def __str__(self):
+        return "(%s)" % (" and ".join(str(ch) for ch in self.children))
+
+
+class OrMatchExpression(MatchExpressionBase):
+    def __init__(self, children):
+        self.children = children
+
+    def __call__(self, kernel, matchable):
+        return any(ch(kernel, matchable) for ch in self.children)
 
-    def __call__(self, identifier, tags):
-        assert isinstance(tags, (tuple, NoneType))
+    def __str__(self):
+        return "(%s)" % (" or ".join(str(ch) for ch in self.children))
+
+
+class NotMatchExpression(MatchExpressionBase):
+    def __init__(self, child):
+        self.child = child
+
+    def __call__(self, kernel, matchable):
+        return not self.child(kernel, matchable)
+
+    def __str__(self):
+        return "(not %s)" % str(self.child)
+
+
+class GlobMatchExpressionBase(MatchExpressionBase):
+    def __init__(self, glob):
+        self.glob = glob
+
+        import re
+        from fnmatch import translate
+        self.re = re.compile("^"+translate(glob.strip())+"$")
 
-        if self.tag_re is None:
-            return self.id_re.match(identifier) is not None
+    def __str__(self):
+        descr = type(self).__name__
+        descr = descr[:descr.find("Match")]
+        return descr.lower() + ":" + self.glob
+
+
+class IdMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return self.re.match(matchable.id)
+
+
+class TagMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        if matchable.tags:
+            return any(self.re.match(tag) for tag in matchable.tags)
         else:
-            if not tags:
-                tags = ("",)
+            return False
 
-            return (
-                    self.id_re.match(identifier) is not None
-                    and any(
-                        self.tag_re.match(tag) is not None
-                        for tag in tags))
 
+class WritesMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return any(self.re.match(name)
+                for name in matchable.write_dependency_names())
 
-class AlternativeMatch(object):
-    def __init__(self, matches):
-        self.matches = matches
 
-    def __call__(self, identifier, tags):
-        from pytools import any
-        return any(
-                mtch(identifier, tags) for mtch in self.matches)
+class ReadsMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return any(self.re.match(name)
+                for name in matchable.read_dependency_names())
+
+
+class InameMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return any(self.re.match(name)
+                for name in matchable.inames(kernel))
 
 # }}}
 
 
-# {{{ single id match parsing
+# {{{ parser
 
-def parse_id_match(id_matches):
+def parse_match(expr_str):
     """Syntax examples::
 
-        my_insn
-        compute_*
-        fetch*$first
-        fetch*$first,store*$first
-
-    Alternatively, a list of *(name_glob, tag_glob)* tuples.
+    * ``id:yoink and writes:a_temp``
+    * ``id:yoink and (not writes:a_temp or tag:input)``
     """
+    if not expr_str:
+        return AllMatchExpression()
+
+    def parse_terminal(pstate):
+        next_tag = pstate.next_tag()
+        if next_tag is _id:
+            result = IdMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _tag:
+            result = TagMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _writes:
+            result = WritesMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _reads:
+            result = ReadsMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _iname:
+            result = InameMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        else:
+            pstate.expected("terminal")
+
+    def inner_parse(pstate, min_precedence=0):
+        pstate.expect_not_end()
+
+        if pstate.is_next(_not):
+            pstate.advance()
+            left_query = NotMatchExpression(inner_parse(pstate, _PREC_NOT))
+        elif pstate.is_next(_openpar):
+            pstate.advance()
+            left_query = inner_parse(pstate)
+            pstate.expect(_closepar)
+            pstate.advance()
+        else:
+            left_query = parse_terminal(pstate)
 
-    if id_matches is None:
-        return AllMatch()
+        did_something = True
+        while did_something:
+            did_something = False
+            if pstate.is_at_end():
+                return left_query
 
-    if isinstance(id_matches, str):
-        id_matches = id_matches.strip()
-        id_matches = id_matches.split(",")
+            next_tag = pstate.next_tag()
 
-    if len(id_matches) > 1:
-        return AlternativeMatch([
-            parse_id_match(im) for im in id_matches])
+            if next_tag is _and and _PREC_AND > min_precedence:
+                pstate.advance()
+                left_query = AndMatchExpression(
+                        (left_query, inner_parse(pstate, _PREC_AND)))
+                did_something = True
+            elif next_tag is _or and _PREC_OR > min_precedence:
+                pstate.advance()
+                left_query = OrMatchExpression(
+                        (left_query, inner_parse(pstate, _PREC_OR)))
+                did_something = True
 
-    if len(id_matches) == 0:
-        return AllMatch()
+        return left_query
 
-    id_match, = id_matches
-    del id_matches
+    from pytools.lex import LexIterator, lex
+    pstate = LexIterator(
+        [(tag, s, idx, matchobj)
+         for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr_str, match_objects=True)
+         if tag is not _whitespace], expr_str)
 
-    def re_from_glob(s):
-        import re
-        from fnmatch import translate
-        return re.compile("^"+translate(s.strip())+"$")
+    if pstate.is_at_end():
+        pstate.raise_parse_error("unexpected end of input")
 
-    if not isinstance(id_match, tuple):
-        components = id_match.split("$")
+    result = inner_parse(pstate)
+    if not pstate.is_at_end():
+        pstate.raise_parse_error("leftover input after completed parse")
 
-    if len(components) == 1:
-        return RegexIdentifierMatch(re_from_glob(components[0]))
-    elif len(components) == 2:
-        return RegexIdentifierMatch(
-                re_from_glob(components[0]),
-                re_from_glob(components[1]))
-    else:
-        raise RuntimeError("too many (%d) $-separated components in id match"
-                % len(components))
+    return result
 
 # }}}
 
+# }}}
 
-# {{{ stack match objects
 
-# these match from the tail of the stack
+# {{{ stack match objects
 
-class StackMatchBase(object):
+class StackMatchComponent(object):
     pass
 
 
-class AllStackMatch(StackMatchBase):
-    def __call__(self, stack):
+class StackAllMatchComponent(StackMatchComponent):
+    def __call__(self, kernel, stack):
         return True
 
 
-class StackIdMatch(StackMatchBase):
-    def __init__(self, id_match, up_match):
-        self.id_match = id_match
-        self.up_match = up_match
+class StackBottomMatchComponent(StackMatchComponent):
+    def __call__(self, kernel, stack):
+        return not stack
+
 
-    def __call__(self, stack):
+class StackItemMatchComponent(StackMatchComponent):
+    def __init__(self, match_expr, inner_match):
+        self.match_expr = match_expr
+        self.inner_match = inner_match
+
+    def __call__(self, kernel, stack):
         if not stack:
             return False
 
-        last = stack[-1]
-        if not self.id_match(*last):
+        outer = stack[0]
+        if not self.match_expr(kernel, outer):
             return False
 
-        if self.up_match is None:
-            return True
-        else:
-            return self.up_match(stack[:-1])
+        return self.inner_match(kernel, stack[1:])
 
 
-class StackWildcardMatch(StackMatchBase):
-    def __init__(self, up_match):
-        self.up_match = up_match
+class StackWildcardMatchComponent(StackMatchComponent):
+    def __init__(self, inner_match):
+        self.inner_match = inner_match
 
-    def __call__(self, stack):
-        if self.up_match is None:
-            return True
+    def __call__(self, kernel, stack):
+        for i in range(0, len(stack)):
+            if self.inner_match(kernel, stack[i:]):
+                return True
 
-        n = len(stack)
+        return False
 
-        if self.up_match(stack):
-            return True
+# }}}
 
-        for i in range(1, n):
-            if self.up_match(stack[:-i]):
-                return True
 
-        return False
+# {{{ stack matcher
+
+class RuleInvocationMatchable(object):
+    def __init__(self, id, tags):
+        self.id = id
+        self.tags = tags
+
+    def write_dependency_names(self):
+        raise TypeError("writes: query may not be applied to rule invocations")
+
+    def read_dependency_names(self):
+        raise TypeError("reads: query may not be applied to rule invocations")
+
+    def inames(self, kernel):
+        raise TypeError("inames: query may not be applied to rule invocations")
+
+
+class StackMatch(object):
+    def __init__(self, root_component):
+        self.root_component = root_component
+
+    def __call__(self, kernel, insn, rule_stack):
+        """
+        :arg rule_stack: a tuple of (name, tags) rule invocation, outermost first
+        """
+        stack_of_matchables = [insn]
+        for id, tags in rule_stack:
+            stack_of_matchables.append(RuleInvocationMatchable(id, tags))
+
+        return self.root_component(kernel, stack_of_matchables)
 
 # }}}
 
@@ -180,34 +355,41 @@ class StackWildcardMatch(StackMatchBase):
 def parse_stack_match(smatch):
     """Syntax example::
 
-        lowest < next < ... < highest
+        ... > outer > ... > next > innermost $
+        insn > next
+        insn > ... > next > innermost $
 
-    where `lowest` is necessarily the bottom of the stack.  ``...`` matches an
-    arbitrary number of intervening stack levels. There is currently no way to
-    match the top of the stack.
+    ``...`` matches an arbitrary number of intervening stack levels.
 
-    Each of the entries is an identifier match as understood by
-    :func:`parse_id_match`.
+    Each of the entries is a match expression as understood by
+    :func:`parse_match`.
     """
 
-    if isinstance(smatch, StackMatchBase):
+    if isinstance(smatch, StackMatch):
         return smatch
 
-    match = AllStackMatch()
-
     if smatch is None:
-        return match
+        return StackMatch(StackAllMatchComponent())
+
+    smatch = smatch.strip()
+
+    match = StackAllMatchComponent()
+    if smatch[-1] == "$":
+        match = StackBottomMatchComponent()
+        smatch = smatch[:-1]
+
+    smatch = smatch.strip()
 
-    components = smatch.split("<")
+    components = smatch.split(">")
 
     for comp in components[::-1]:
         comp = comp.strip()
         if comp == "...":
-            match = StackWildcardMatch(match)
+            match = StackWildcardMatchComponent(match)
         else:
-            match = StackIdMatch(parse_id_match(comp), match)
+            match = StackItemMatchComponent(parse_match(comp), match)
 
-    return match
+    return StackMatch(match)
 
 # }}}
 
diff --git a/loopy/fusion.py b/loopy/fusion.py
index 4431c2c7f..c14d936af 100644
--- a/loopy/fusion.py
+++ b/loopy/fusion.py
@@ -170,11 +170,13 @@ def _fuse_two_kernels(knla, knlb):
             SubstitutionRuleMappingContext,
             RuleAwareSubstitutionMapper)
     from pymbolic.mapper.substitutor import make_subst_func
+    from loopy.context_matching import parse_stack_match
 
     srmc = SubstitutionRuleMappingContext(
             knlb.substitutions, knlb.get_var_name_generator())
     subst_map = RuleAwareSubstitutionMapper(
-            srmc, make_subst_func(b_var_renames), within=lambda stack: True)
+            srmc, make_subst_func(b_var_renames),
+            within=parse_stack_match(None))
     knlb = subst_map.map_kernel(knlb)
 
     # }}}
diff --git a/loopy/precompute.py b/loopy/precompute.py
index 726cc0786..1227082a9 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -81,7 +81,10 @@ class RuleInvocationGatherer(RuleAwareIdentityMapper):
         if self.subst_tag is not None and self.subst_tag != tag:
             process_me = False
 
-        process_me = process_me and self.within(expn_state.stack)
+        process_me = process_me and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack)
 
         if not process_me:
             return super(RuleInvocationGatherer, self).map_substitution(
@@ -151,7 +154,10 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
     def map_substitution(self, name, tag, arguments, expn_state):
         if not (
                 name == self.subst_name
-                and self.within(expn_state.stack)
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)
                 and (self.subst_tag is None or self.subst_tag == tag)):
             return super(RuleInvocationReplacer, self).map_substitution(
                     name, tag, arguments, expn_state)
@@ -387,8 +393,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
         import loopy as lp
         for insn in kernel.instructions:
             if isinstance(insn, lp.ExpressionInstruction):
-                invg(insn.assignee, insn.id, insn.tags)
-                invg(insn.expression, insn.id, insn.tags)
+                invg(insn.assignee, kernel, insn)
+                invg(insn.expression, kernel, insn)
 
         access_descriptors = invg.access_descriptors
         if not access_descriptors:
@@ -614,13 +620,13 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     rule_mapping_context = SubstitutionRuleMappingContext(
             kernel.substitutions, kernel.get_var_name_generator())
 
-    from loopy.context_matching import AllStackMatch
+    from loopy.context_matching import parse_stack_match
     expr_subst_map = RuleAwareSubstitutionMapper(
             rule_mapping_context,
             make_subst_func(storage_axis_subst_dict),
-            within=AllStackMatch())
+            within=parse_stack_match(None))
 
-    compute_expression = expr_subst_map(subst.expression, None, None)
+    compute_expression = expr_subst_map(subst.expression, kernel, None)
 
     # }}}
 
diff --git a/loopy/subst.py b/loopy/subst.py
index 3b112a4fd..a0a031718 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -258,7 +258,10 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper):
 
         my_def_id = self.usage_to_definition[my_insn_id]
 
-        if not self.within(expn_state.stack):
+        if not self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             self.saw_unmatched_usage_sites[my_def_id] = True
             return None
 
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index bad7f840f..3311d2316 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -398,11 +398,13 @@ def parse_tagged_name(expr):
 
 class ExpansionState(Record):
     """
+    .. attribute:: kernel
+    .. attribute:: instruction
+
     .. attribute:: stack
 
         a tuple representing the current expansion stack, as a tuple
-        of (name, tag) pairs. At the top level, this should be initialized to a
-        tuple with the id of the calling instruction.
+        of (name, tag) pairs.
 
     .. attribute:: arg_context
 
@@ -411,7 +413,7 @@ class ExpansionState(Record):
 
     @property
     def insn_id(self):
-        return self.stack[0][0]
+        return self.instruction.id
 
     def apply_arg_context(self, expr):
         from pymbolic.mapper.substitutor import make_subst_func
@@ -625,16 +627,18 @@ class RuleAwareIdentityMapper(IdentityMapper):
         else:
             return sym
 
-    def __call__(self, expr, insn_id, insn_tags):
-        if insn_id is not None:
-            stack = ((insn_id, insn_tags),)
-        else:
-            stack = ()
+    def __call__(self, expr, kernel, insn):
+        from loopy.kernel.data import InstructionBase
+        assert insn is None or isinstance(insn, InstructionBase)
 
-        return IdentityMapper.__call__(self, expr, ExpansionState(
-            stack=stack, arg_context={}))
+        return IdentityMapper.__call__(self, expr,
+                ExpansionState(
+                    kernel=kernel,
+                    instruction=insn,
+                    stack=(),
+                    arg_context={}))
 
-    def map_instruction(self, insn):
+    def map_instruction(self, kernel, insn):
         return insn
 
     def map_kernel(self, kernel):
@@ -642,8 +646,8 @@ class RuleAwareIdentityMapper(IdentityMapper):
                 # While subst rules are not allowed in assignees, the mapper
                 # may perform tasks entirely unrelated to subst rules, so
                 # we must map assignees, too.
-                self.map_instruction(
-                    insn.with_transformed_expressions(self, insn.id, insn.tags))
+                self.map_instruction(kernel,
+                    insn.with_transformed_expressions(self, kernel, insn))
                 for insn in kernel.instructions]
 
         return kernel.copy(instructions=new_insns)
@@ -658,7 +662,8 @@ class RuleAwareSubstitutionMapper(RuleAwareIdentityMapper):
 
     def map_variable(self, expr, expn_state):
         if (expr.name in expn_state.arg_context
-                or not self.within(expn_state.stack)):
+                or not self.within(
+                    expn_state.kernel, expn_state.instruction, expn_state.stack)):
             return super(RuleAwareSubstitutionMapper, self).map_variable(
                     expr, expn_state)
 
@@ -685,7 +690,7 @@ class RuleAwareSubstitutionRuleExpander(RuleAwareIdentityMapper):
 
         new_stack = expn_state.stack + ((name, tags),)
 
-        if self.within(new_stack):
+        if self.within(expn_state.kernel, expn_state.instruction, new_stack):
             # expand
             rule = self.rules[name]
 
diff --git a/test/test_fortran.py b/test/test_fortran.py
index d361b15dc..f49355044 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -267,7 +267,7 @@ def test_tagged(ctx_factory):
 
     knl, = lp.parse_fortran(fortran_src)
 
-    assert sum(1 for insn in lp.find_instructions(knl, "*$input")) == 2
+    assert sum(1 for insn in lp.find_instructions(knl, "tag:input")) == 2
 
 
 @pytest.mark.parametrize("buffer_inames", [
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 2173347ca..1527e4ff7 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -65,7 +65,7 @@ def test_complicated_subst(ctx_factory):
                 a[i] = h$one(i) * h$two(i)
                 """)
 
-    knl = lp.expand_subst(knl, "g$two < h$two")
+    knl = lp.expand_subst(knl, "... > id:h and tag:two > id:g and tag:two")
 
     print(knl)
 
diff --git a/test/test_sem_reagan.py b/test/test_sem_reagan.py
index 33d15f88d..a00fce177 100644
--- a/test/test_sem_reagan.py
+++ b/test/test_sem_reagan.py
@@ -39,7 +39,7 @@ def test_tim2d(ctx_factory):
     n = 8
 
     from pymbolic import var
-    K_sym = var("K")
+    K_sym = var("K")  # noqa
 
     field_shape = (K_sym, n, n)
 
@@ -70,8 +70,8 @@ def test_tim2d(ctx_factory):
                 ],
             name="semlap2D", assumptions="K>=1")
 
-    knl = lp.duplicate_inames(knl, "o", within="ur")
-    knl = lp.duplicate_inames(knl, "o", within="us")
+    knl = lp.duplicate_inames(knl, "o", within="id:ur")
+    knl = lp.duplicate_inames(knl, "o", within="id:us")
 
     seq_knl = knl
 
@@ -93,13 +93,13 @@ def test_tim2d(ctx_factory):
         knl = lp.tag_inames(knl, dict(o="unr"))
         knl = lp.tag_inames(knl, dict(m="unr"))
 
-        knl = lp.set_instruction_priority(knl, "D_fetch", 5)
+        knl = lp.set_instruction_priority(knl, "id:D_fetch", 5)
         print(knl)
 
         return knl
 
     for variant in [variant_orig]:
-        K = 1000
+        K = 1000  # noqa
         lp.auto_test_vs_ref(seq_knl, ctx, variant(knl),
                 op_count=[K*(n*n*n*2*2 + n*n*2*3 + n**3 * 2*2)/1e9],
                 op_label=["GFlops"],
-- 
GitLab