From 90597d80b55f35f48492fece69b25ee19036733f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 22 Jun 2015 22:20:42 -0500
Subject: [PATCH] Implement kernel query language

---
 doc/reference.rst         |   2 +-
 loopy/__init__.py         |  63 ++++---
 loopy/buffer.py           |  12 +-
 loopy/context_matching.py | 388 ++++++++++++++++++++++++++++----------
 loopy/fusion.py           |   4 +-
 loopy/precompute.py       |  20 +-
 loopy/subst.py            |   5 +-
 loopy/symbolic.py         |  35 ++--
 test/test_fortran.py      |   2 +-
 test/test_loopy.py        |   2 +-
 test/test_sem_reagan.py   |  10 +-
 11 files changed, 384 insertions(+), 159 deletions(-)

diff --git a/doc/reference.rst b/doc/reference.rst
index 6ed83d430..c7435bbf7 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -356,7 +356,7 @@ TODO: Matching instruction tags
 
 .. automodule:: loopy.context_matching
 
-.. autofunction:: parse_id_match
+.. autofunction:: parse_match
 
 .. autofunction:: parse_stack_match
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index dcd8e27c1..58029467c 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -152,7 +152,10 @@ class _InameSplitter(RuleAwareIdentityMapper):
     def map_reduction(self, expr, expn_state):
         if (self.split_iname in expr.inames
                 and self.split_iname not in expn_state.arg_context
-                and self.within(expn_state.stack)):
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             new_inames = list(expr.inames)
             new_inames.remove(self.split_iname)
             new_inames.extend([self.outer_iname, self.inner_iname])
@@ -166,7 +169,10 @@ class _InameSplitter(RuleAwareIdentityMapper):
     def map_variable(self, expr, expn_state):
         if (expr.name == self.split_iname
                 and self.split_iname not in expn_state.arg_context
-                and self.within(expn_state.stack)):
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             return self.replacement_index
         else:
             return super(_InameSplitter, self).map_variable(expr, expn_state)
@@ -318,7 +324,10 @@ class _InameJoiner(RuleAwareSubstitutionMapper):
         expr_inames = set(expr.inames)
         overlap = (self.join_inames & expr_inames
                 - set(expn_state.arg_context))
-        if overlap and self.within(expn_state.stack):
+        if overlap and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             if overlap != expr_inames:
                 raise LoopyError(
                         "Cannot join inames '%s' if there is a reduction "
@@ -520,7 +529,10 @@ class _InameDuplicator(RuleAwareIdentityMapper):
 
     def map_reduction(self, expr, expn_state):
         if (set(expr.inames) & self.old_inames_set
-                and self.within(expn_state.stack)):
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             new_inames = tuple(
                     self.old_to_new.get(iname, iname)
                     if iname not in expn_state.arg_context
@@ -538,14 +550,17 @@ class _InameDuplicator(RuleAwareIdentityMapper):
 
         if (new_name is None
                 or expr.name in expn_state.arg_context
-                or not self.within(expn_state.stack)):
+                or not self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             return super(_InameDuplicator, self).map_variable(expr, expn_state)
         else:
             from pymbolic import var
             return var(new_name)
 
-    def map_instruction(self, insn):
-        if not self.within(((insn.id, insn.tags),)):
+    def map_instruction(self, kernel, insn):
+        if not self.within(kernel, insn, ()):
             return insn
 
         new_fid = frozenset(
@@ -1050,7 +1065,7 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 
     # If the rule survived past precompute() (i.e. some accesses fell outside
    # the footprint), get rid of it before moving on.
     if rule_name in new_kernel.substitutions:
-        return expand_subst(new_kernel, rule_name)
+        return expand_subst(new_kernel, "id:"+rule_name)
     else:
         return new_kernel
 
@@ -1060,19 +1075,19 @@ def add_prefetch(kernel, var_name, sweep_inames=[], dim_arg_names=None,
 # {{{ instruction processing
 
 def find_instructions(kernel, insn_match):
-    from loopy.context_matching import parse_id_match
-    match = parse_id_match(insn_match)
-    return [insn for insn in kernel.instructions if match(insn.id, insn.tags)]
+    from loopy.context_matching import parse_match
+    match = parse_match(insn_match)
+    return [insn for insn in kernel.instructions if match(kernel, insn)]
 
 
 def map_instructions(kernel, insn_match, f):
-    from loopy.context_matching import parse_id_match
-    match = parse_id_match(insn_match)
+    from loopy.context_matching import parse_match
+    match = parse_match(insn_match)
 
     new_insns = []
 
     for insn in kernel.instructions:
-        if match(insn.id, None):
+        if match(kernel, insn):
             new_insns.append(f(insn))
         else:
             new_insns.append(insn)
@@ -1084,7 +1099,7 @@ def set_instruction_priority(kernel, insn_match, priority):
     """Set the priority of instructions matching *insn_match* to *priority*.
 
     *insn_match* may be any instruction id match understood by
-    :func:`loopy.context_matching.parse_id_match`.
+    :func:`loopy.context_matching.parse_match`.
     """
 
     def set_prio(insn):
@@ -1098,7 +1113,7 @@ def add_dependency(kernel, insn_match, dependency):
     by *insn_match*.
 
    *insn_match* may be any instruction id match understood by
-    :func:`loopy.context_matching.parse_id_match`.
+    :func:`loopy.context_matching.parse_match`.
     """
 
     def add_dep(insn):
@@ -1220,7 +1235,11 @@ class _ReductionSplitter(RuleAwareIdentityMapper):
             # FIXME
             raise NotImplementedError()
 
-        if self.inames <= set(expr.inames) and self.within(expn_state.stack):
+        if (self.inames <= set(expr.inames)
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)):
             leftover_inames = set(expr.inames) - self.inames
 
             from loopy.symbolic import Reduction
@@ -1659,10 +1678,12 @@ def affine_map_inames(kernel, old_inames, new_inames, equations):
     var_name_gen = kernel.get_var_name_generator()
 
     from pymbolic.mapper.substitutor import make_subst_func
+    from loopy.context_matching import parse_stack_match
+
     rule_mapping_context = SubstitutionRuleMappingContext(
             kernel.substitutions, var_name_gen)
     old_to_new = RuleAwareSubstitutionMapper(rule_mapping_context,
-            make_subst_func(subst_dict), within=lambda stack: True)
+            make_subst_func(subst_dict), within=parse_stack_match(None))
 
     kernel = (
             rule_mapping_context.finish_kernel(
@@ -1792,12 +1813,12 @@ def fold_constants(kernel):
 # {{{ tag_instructions
 
 def tag_instructions(kernel, new_tag, within=None):
-    from loopy.context_matching import parse_stack_match
-    within = parse_stack_match(within)
+    from loopy.context_matching import parse_match
+    within = parse_match(within)
 
     new_insns = []
     for insn in kernel.instructions:
-        if within(((insn.id, insn.tags),)):
+        if within(kernel, insn):
             new_insns.append(
                     insn.copy(tags=insn.tags + (new_tag,)))
         else:
diff --git a/loopy/buffer.py b/loopy/buffer.py
index d155dba7e..fdc3774b2 100644
--- a/loopy/buffer.py
+++ b/loopy/buffer.py
@@ -51,7 +51,10 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
 
     def map_variable(self, expr, expn_state):
         result = None
-        if expr.name == self.var_name and self.within(expn_state):
+        if expr.name == self.var_name and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             result = self.map_array_access((), expn_state)
 
         if result is None:
@@ -62,7 +65,10 @@ class ArrayAccessReplacer(RuleAwareIdentityMapper):
 
     def map_subscript(self, expr, expn_state):
         result = None
-        if expr.aggregate.name == self.var_name and self.within(expn_state):
+        if expr.aggregate.name == self.var_name and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             result = self.map_array_access(expr.index, expn_state)
 
         if result is None:
@@ -172,7 +178,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None,
     access_descriptors = []
 
     for insn in kernel.instructions:
-        if not within((insn.id, insn.tags)):
+        if not within(kernel, insn, ()):
             continue
 
         for assignee, index in insn.assignees_and_indices():
diff --git a/loopy/context_matching.py b/loopy/context_matching.py
index b259a0ddd..cb0d6e956 100644
--- a/loopy/context_matching.py
+++ b/loopy/context_matching.py
@@ -29,148 +29,323 @@ THE SOFTWARE.
 
 NoneType = type(None)
 
+from pytools.lex import RE
+
 
-# {{{ id match objects
-
-class AllMatch(object):
-    def __call__(self, identifier, tag):
+def re_from_glob(s):
+    import re
+    from fnmatch import translate
+    return re.compile("^"+translate(s.strip())+"$")
+
+# {{{ parsing
+
+# {{{ lexer data
+
+_and = intern("and")
+_or = intern("or")
+_not = intern("not")
+_openpar = intern("openpar")
+_closepar = intern("closepar")
+
+_id = intern("_id")
+_tag = intern("_tag")
+_writes = intern("_writes")
+_reads = intern("_reads")
+_iname = intern("_iname")
+
+_whitespace = intern("_whitespace")
+
+# }}}
+
+
+_LEX_TABLE = [
+    (_and, RE(r"and\b")),
+    (_or, RE(r"or\b")),
+    (_not, RE(r"not\b")),
+    (_openpar, RE(r"\(")),
+    (_closepar, RE(r"\)")),
+
+    # TERMINALS
+    (_id, RE(r"id:([\w?*]+)")),
+    (_tag, RE(r"tag:([\w?*]+)")),
+    (_writes, RE(r"writes:([\w?*]+)")),
+    (_reads, RE(r"reads:([\w?*]+)")),
+    (_iname, RE(r"iname:([\w?*]+)")),
+
+    (_whitespace, RE("[ \t]+")),
+    ]
+
+
+_TERMINALS = ([_id, _tag, _writes, _reads, _iname])
+
+# {{{ operator precedence
+
+_PREC_OR = 10
+_PREC_AND = 20
+_PREC_NOT = 30
+
+# }}}
+
+
+# {{{ match expression
+
+class MatchExpressionBase(object):
+    def __call__(self, kernel, matchable):
+        raise NotImplementedError
+
+
+class AllMatchExpression(MatchExpressionBase):
+    def __call__(self, kernel, matchable):
         return True
 
 
-class RegexIdentifierMatch(object):
-    def __init__(self, id_re, tag_re=None):
-        self.id_re = id_re
-        self.tag_re = tag_re
+class AndMatchExpression(MatchExpressionBase):
+    def __init__(self, children):
+        self.children = children
+
+    def __call__(self, kernel, matchable):
+        return all(ch(kernel, matchable) for ch in self.children)
+
+    def __str__(self):
+        return "(%s)" % (" and ".join(str(ch) for ch in self.children))
+
+
+class OrMatchExpression(MatchExpressionBase):
+    def __init__(self, children):
+        self.children = children
+
+    def __call__(self, kernel, matchable):
+        return any(ch(kernel, matchable) for ch in self.children)
 
-    def __call__(self, identifier, tags):
-        assert isinstance(tags, (tuple, NoneType))
+    def __str__(self):
+        return "(%s)" % (" or ".join(str(ch) for ch in self.children))
+
+
+class NotMatchExpression(MatchExpressionBase):
+    def __init__(self, child):
+        self.child = child
+
+    def __call__(self, kernel, matchable):
+        return not self.child(kernel, matchable)
+
+    def __str__(self):
+        return "(not %s)" % str(self.child)
+
+
+class GlobMatchExpressionBase(MatchExpressionBase):
+    def __init__(self, glob):
+        self.glob = glob
+
+        import re
+        from fnmatch import translate
+        self.re = re.compile("^"+translate(glob.strip())+"$")
 
-        if self.tag_re is None:
-            return self.id_re.match(identifier) is not None
+    def __str__(self):
+        descr = type(self).__name__
+        descr = descr[:descr.find("Match")]
+        return descr.lower() + ":" + self.glob
+
+
+class IdMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return self.re.match(matchable.id)
+
+
+class TagMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        if matchable.tags:
+            return any(self.re.match(tag) for tag in matchable.tags)
         else:
-            if not tags:
-                tags = ("",)
+            return False
 
-            return (
-                    self.id_re.match(identifier) is not None
-                    and any(
-                        self.tag_re.match(tag) is not None
-                        for tag in tags))
 
+class WritesMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return any(self.re.match(name)
+                for name in matchable.write_dependency_names())
 
-class AlternativeMatch(object):
-    def __init__(self, matches):
-        self.matches = matches
 
-    def __call__(self, identifier, tags):
-        from pytools import any
-        return any(
-                mtch(identifier, tags) for mtch in self.matches)
+class ReadsMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return any(self.re.match(name)
+                for name in matchable.read_dependency_names())
+
+
+class InameMatchExpression(GlobMatchExpressionBase):
+    def __call__(self, kernel, matchable):
+        return any(self.re.match(name)
+                for name in matchable.inames(kernel))
 
 # }}}
 
 
-# {{{ single id match parsing
+# {{{ parser
 
-def parse_id_match(id_matches):
+def parse_match(expr_str):
     """Syntax examples::
 
-        my_insn
-        compute_*
-        fetch*$first
-        fetch*$first,store*$first
-
-    Alternatively, a list of *(name_glob, tag_glob)* tuples.
+    * ``id:yoink and writes:a_temp``
+    * ``id:yoink and (not writes:a_temp or tag:input)``
     """
+    if not expr_str:
+        return AllMatchExpression()
+
+    def parse_terminal(pstate):
+        next_tag = pstate.next_tag()
+        if next_tag is _id:
+            result = IdMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _tag:
+            result = TagMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _writes:
+            result = WritesMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _reads:
+            result = ReadsMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        elif next_tag is _iname:
+            result = InameMatchExpression(pstate.next_match_obj().group(1))
+            pstate.advance()
+            return result
+        else:
+            pstate.expected("terminal")
+
+    def inner_parse(pstate, min_precedence=0):
+        pstate.expect_not_end()
+
+        if pstate.is_next(_not):
+            pstate.advance()
+            left_query = NotMatchExpression(inner_parse(pstate, _PREC_NOT))
+        elif pstate.is_next(_openpar):
+            pstate.advance()
+            left_query = inner_parse(pstate)
+            pstate.expect(_closepar)
+            pstate.advance()
+        else:
+            left_query = parse_terminal(pstate)
 
-    if id_matches is None:
-        return AllMatch()
+        did_something = True
+        while did_something:
+            did_something = False
+            if pstate.is_at_end():
+                return left_query
 
-    if isinstance(id_matches, str):
-        id_matches = id_matches.strip()
-        id_matches = id_matches.split(",")
+            next_tag = pstate.next_tag()
 
-    if len(id_matches) > 1:
-        return AlternativeMatch([
-            parse_id_match(im) for im in id_matches])
+            if next_tag is _and and _PREC_AND > min_precedence:
+                pstate.advance()
+                left_query = AndMatchExpression(
+                        (left_query, inner_parse(pstate, _PREC_AND)))
+                did_something = True
+            elif next_tag is _or and _PREC_OR > min_precedence:
+                pstate.advance()
+                left_query = OrMatchExpression(
+                        (left_query, inner_parse(pstate, _PREC_OR)))
+                did_something = True
 
-    if len(id_matches) == 0:
-        return AllMatch()
+        return left_query
 
-    id_match, = id_matches
-    del id_matches
+    from pytools.lex import LexIterator, lex
+    pstate = LexIterator(
+        [(tag, s, idx, matchobj)
+        for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr_str, match_objects=True)
+        if tag is not _whitespace], expr_str)
 
-    def re_from_glob(s):
-        import re
-        from fnmatch import translate
-        return re.compile("^"+translate(s.strip())+"$")
+    if pstate.is_at_end():
+        pstate.raise_parse_error("unexpected end of input")
 
-    if not isinstance(id_match, tuple):
-        components = id_match.split("$")
+    result = inner_parse(pstate)
+    if not pstate.is_at_end():
+        pstate.raise_parse_error("leftover input after completed parse")
 
-        if len(components) == 1:
-            return RegexIdentifierMatch(re_from_glob(components[0]))
-        elif len(components) == 2:
-            return RegexIdentifierMatch(
-                    re_from_glob(components[0]),
-                    re_from_glob(components[1]))
-        else:
-            raise RuntimeError("too many (%d) $-separated components in id match"
-                    % len(components))
+    return result
 
 # }}}
 
+# }}}
+
 
-# {{{ stack match objects
-# these match from the tail of the stack
+# {{{ stack match objects
 
-class StackMatchBase(object):
+class StackMatchComponent(object):
     pass
 
 
-class AllStackMatch(StackMatchBase):
-    def __call__(self, stack):
+class StackAllMatchComponent(StackMatchComponent):
+    def __call__(self, kernel, stack):
         return True
 
 
-class StackIdMatch(StackMatchBase):
-    def __init__(self, id_match, up_match):
-        self.id_match = id_match
-        self.up_match = up_match
+class StackBottomMatchComponent(StackMatchComponent):
+    def __call__(self, kernel, stack):
+        return not stack
+
 
-    def __call__(self, stack):
+class StackItemMatchComponent(StackMatchComponent):
+    def __init__(self, match_expr, inner_match):
+        self.match_expr = match_expr
+        self.inner_match = inner_match
+
+    def __call__(self, kernel, stack):
         if not stack:
             return False
 
-        last = stack[-1]
-        if not self.id_match(*last):
+        outer = stack[0]
+        if not self.match_expr(kernel, outer):
             return False
 
-        if self.up_match is None:
-            return True
-        else:
-            return self.up_match(stack[:-1])
+        return self.inner_match(kernel, stack[1:])
 
 
-class StackWildcardMatch(StackMatchBase):
-    def __init__(self, up_match):
-        self.up_match = up_match
+class StackWildcardMatchComponent(StackMatchComponent):
+    def __init__(self, inner_match):
+        self.inner_match = inner_match
 
-    def __call__(self, stack):
-        if self.up_match is None:
-            return True
+    def __call__(self, kernel, stack):
+        for i in range(0, len(stack)):
+            if self.inner_match(kernel, stack[i:]):
+                return True
 
-        n = len(stack)
+        return False
 
-        if self.up_match(stack):
-            return True
+# }}}
 
-        for i in range(1, n):
-            if self.up_match(stack[:-i]):
-                return True
 
-        return False
+# {{{ stack matcher
+
+class RuleInvocationMatchable(object):
+    def __init__(self, id, tags):
+        self.id = id
+        self.tags = tags
+
+    def write_dependency_names(self):
+        raise TypeError("writes: query may not be applied to rule invocations")
+
+    def read_dependency_names(self):
+        raise TypeError("reads: query may not be applied to rule invocations")
+
+    def inames(self, kernel):
+        raise TypeError("iname: query may not be applied to rule invocations")
+
+
+class StackMatch(object):
+    def __init__(self, root_component):
+        self.root_component = root_component
+
+    def __call__(self, kernel, insn, rule_stack):
+        """
+        :arg rule_stack: a tuple of (name, tags) rule invocations, outermost first
+        """
+
+        stack_of_matchables = [insn]
+        for id, tags in rule_stack:
+            stack_of_matchables.append(RuleInvocationMatchable(id, tags))
+
+        return self.root_component(kernel, stack_of_matchables)
 
 # }}}
 
@@ -180,34 +355,41 @@ class StackWildcardMatch(StackMatchBase):
 
 def parse_stack_match(smatch):
     """Syntax example::
 
-        lowest < next < ... < highest
+        ... > outer > ... > next > innermost $
+        insn > next
+        insn > ... > next > innermost $
 
-    where `lowest` is necessarily the bottom of the stack.  ``...`` matches an
-    arbitrary number of intervening stack levels. There is currently no way to
-    match the top of the stack.
+    ``...`` matches an arbitrary number of intervening stack levels.
 
-    Each of the entries is an identifier match as understood by
-    :func:`parse_id_match`.
+    Each of the entries is a match expression as understood by
+    :func:`parse_match`.
     """
 
-    if isinstance(smatch, StackMatchBase):
+    if isinstance(smatch, StackMatch):
         return smatch
 
-    match = AllStackMatch()
     if smatch is None:
-        return match
+        return StackMatch(StackAllMatchComponent())
+
+    smatch = smatch.strip()
+
+    match = StackAllMatchComponent()
+    if smatch[-1] == "$":
+        match = StackBottomMatchComponent()
+        smatch = smatch[:-1]
+
+    smatch = smatch.strip()
 
-    components = smatch.split("<")
+    components = smatch.split(">")
     for comp in components[::-1]:
         comp = comp.strip()
         if comp == "...":
-            match = StackWildcardMatch(match)
+            match = StackWildcardMatchComponent(match)
         else:
-            match = StackIdMatch(parse_id_match(comp), match)
+            match = StackItemMatchComponent(parse_match(comp), match)
 
-    return match
+    return StackMatch(match)
 
 # }}}
diff --git a/loopy/fusion.py b/loopy/fusion.py
index 4431c2c7f..c14d936af 100644
--- a/loopy/fusion.py
+++ b/loopy/fusion.py
@@ -170,11 +170,13 @@ def _fuse_two_kernels(knla, knlb):
             SubstitutionRuleMappingContext,
             RuleAwareSubstitutionMapper)
     from pymbolic.mapper.substitutor import make_subst_func
+    from loopy.context_matching import parse_stack_match
 
     srmc = SubstitutionRuleMappingContext(
             knlb.substitutions, knlb.get_var_name_generator())
     subst_map = RuleAwareSubstitutionMapper(
-            srmc, make_subst_func(b_var_renames), within=lambda stack: True)
+            srmc, make_subst_func(b_var_renames),
+            within=parse_stack_match(None))
     knlb = subst_map.map_kernel(knlb)
 
 # }}}
diff --git a/loopy/precompute.py b/loopy/precompute.py
index 726cc0786..1227082a9 100644
--- a/loopy/precompute.py
+++ b/loopy/precompute.py
@@ -81,7 +81,10 @@ class RuleInvocationGatherer(RuleAwareIdentityMapper):
         if self.subst_tag is not None and self.subst_tag != tag:
             process_me = False
 
-        process_me = process_me and self.within(expn_state.stack)
+        process_me = process_me and self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack)
 
         if not process_me:
             return super(RuleInvocationGatherer, self).map_substitution(
@@ -151,7 +154,10 @@ class RuleInvocationReplacer(RuleAwareIdentityMapper):
     def map_substitution(self, name, tag, arguments, expn_state):
         if not (
                 name == self.subst_name
-                and self.within(expn_state.stack)
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction,
+                    expn_state.stack)
                 and (self.subst_tag is None or self.subst_tag == tag)):
             return super(RuleInvocationReplacer, self).map_substitution(
                     name, tag, arguments, expn_state)
@@ -387,8 +393,8 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
     import loopy as lp
     for insn in kernel.instructions:
         if isinstance(insn, lp.ExpressionInstruction):
-            invg(insn.assignee, insn.id, insn.tags)
-            invg(insn.expression, insn.id, insn.tags)
+            invg(insn.assignee, kernel, insn)
+            invg(insn.expression, kernel, insn)
 
     access_descriptors = invg.access_descriptors
     if not access_descriptors:
@@ -614,13 +620,13 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None,
 
     rule_mapping_context = SubstitutionRuleMappingContext(
             kernel.substitutions, kernel.get_var_name_generator())
-    from loopy.context_matching import AllStackMatch
+    from loopy.context_matching import parse_stack_match
     expr_subst_map = RuleAwareSubstitutionMapper(
             rule_mapping_context, make_subst_func(storage_axis_subst_dict),
-            within=AllStackMatch())
+            within=parse_stack_match(None))
 
-    compute_expression = expr_subst_map(subst.expression, None, None)
+    compute_expression = expr_subst_map(subst.expression, kernel, None)
 
     # }}}
diff --git a/loopy/subst.py b/loopy/subst.py
index 3b112a4fd..a0a031718 100644
--- a/loopy/subst.py
+++ b/loopy/subst.py
@@ -258,7 +258,10 @@ class TemporaryToSubstChanger(RuleAwareIdentityMapper):
 
         my_def_id = self.usage_to_definition[my_insn_id]
 
-        if not self.within(expn_state.stack):
+        if not self.within(
+                expn_state.kernel,
+                expn_state.instruction,
+                expn_state.stack):
             self.saw_unmatched_usage_sites[my_def_id] = True
             return None
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index bad7f840f..3311d2316 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -398,11 +398,13 @@ def parse_tagged_name(expr):
 
 class ExpansionState(Record):
     """
+    .. attribute:: kernel
+
+    .. attribute:: instruction
+
     .. attribute:: stack
 
         a tuple representing the current expansion stack, as a tuple
-        of (name, tag) pairs. At the top level, this should be initialized to a
-        tuple with the id of the calling instruction.
+        of (name, tag) pairs.
 
     .. attribute:: arg_context
@@ -411,7 +413,7 @@ class ExpansionState(Record):
 
     @property
     def insn_id(self):
-        return self.stack[0][0]
+        return self.instruction.id
 
     def apply_arg_context(self, expr):
         from pymbolic.mapper.substitutor import make_subst_func
@@ -625,16 +627,18 @@ class RuleAwareIdentityMapper(IdentityMapper):
         else:
             return sym
 
-    def __call__(self, expr, insn_id, insn_tags):
-        if insn_id is not None:
-            stack = ((insn_id, insn_tags),)
-        else:
-            stack = ()
+    def __call__(self, expr, kernel, insn):
+        from loopy.kernel.data import InstructionBase
+        assert insn is None or isinstance(insn, InstructionBase)
 
-        return IdentityMapper.__call__(self, expr, ExpansionState(
-            stack=stack, arg_context={}))
+        return IdentityMapper.__call__(self, expr,
+                ExpansionState(
+                    kernel=kernel,
+                    instruction=insn,
+                    stack=(),
+                    arg_context={}))
 
-    def map_instruction(self, insn):
+    def map_instruction(self, kernel, insn):
         return insn
 
     def map_kernel(self, kernel):
         new_insns = [
                 # While subst rules are not allowed in assignees, the mapper
                 # may perform tasks entirely unrelated to subst rules, so
                # we must map assignees, too.
-                self.map_instruction(
-                        insn.with_transformed_expressions(self, insn.id, insn.tags))
+                self.map_instruction(kernel,
+                        insn.with_transformed_expressions(self, kernel, insn))
                 for insn in kernel.instructions]
 
         return kernel.copy(instructions=new_insns)
@@ -658,7 +662,8 @@ class RuleAwareSubstitutionMapper(RuleAwareIdentityMapper):
 
     def map_variable(self, expr, expn_state):
         if (expr.name in expn_state.arg_context
-                or not self.within(expn_state.stack)):
+                or not self.within(
+                    expn_state.kernel, expn_state.instruction, expn_state.stack)):
             return super(RuleAwareSubstitutionMapper, self).map_variable(
                     expr, expn_state)
 
@@ -685,7 +690,7 @@ class RuleAwareSubstitutionRuleExpander(RuleAwareIdentityMapper):
 
         new_stack = expn_state.stack + ((name, tags),)
 
-        if self.within(new_stack):
+        if self.within(expn_state.kernel, expn_state.instruction, new_stack):
             # expand
             rule = self.rules[name]
diff --git a/test/test_fortran.py b/test/test_fortran.py
index d361b15dc..f49355044 100644
--- a/test/test_fortran.py
+++ b/test/test_fortran.py
@@ -267,7 +267,7 @@ def test_tagged(ctx_factory):
 
     knl, = lp.parse_fortran(fortran_src)
 
-    assert sum(1 for insn in lp.find_instructions(knl, "*$input")) == 2
+    assert sum(1 for insn in lp.find_instructions(knl, "tag:input")) == 2
 
 
 @pytest.mark.parametrize("buffer_inames", [
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 2173347ca..1527e4ff7 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -65,7 +65,7 @@ def test_complicated_subst(ctx_factory):
                 a[i] = h$one(i) * h$two(i)
                 """)
 
-    knl = lp.expand_subst(knl, "g$two < h$two")
+    knl = lp.expand_subst(knl, "... > id:h and tag:two > id:g and tag:two")
 
     print(knl)
 
diff --git a/test/test_sem_reagan.py b/test/test_sem_reagan.py
index 33d15f88d..a00fce177 100644
--- a/test/test_sem_reagan.py
+++ b/test/test_sem_reagan.py
@@ -39,7 +39,7 @@ def test_tim2d(ctx_factory):
     n = 8
 
     from pymbolic import var
-    K_sym = var("K")
+    K_sym = var("K")  # noqa
 
     field_shape = (K_sym, n, n)
 
@@ -70,8 +70,8 @@ def test_tim2d(ctx_factory):
                 ],
             name="semlap2D", assumptions="K>=1")
 
-    knl = lp.duplicate_inames(knl, "o", within="ur")
-    knl = lp.duplicate_inames(knl, "o", within="us")
+    knl = lp.duplicate_inames(knl, "o", within="id:ur")
+    knl = lp.duplicate_inames(knl, "o", within="id:us")
 
     seq_knl = knl
 
@@ -93,13 +93,13 @@ def test_tim2d(ctx_factory):
         knl = lp.tag_inames(knl, dict(o="unr"))
         knl = lp.tag_inames(knl, dict(m="unr"))
 
-        knl = lp.set_instruction_priority(knl, "D_fetch", 5)
+        knl = lp.set_instruction_priority(knl, "id:D_fetch", 5)
 
         print(knl)
         return knl
 
     for variant in [variant_orig]:
-        K = 1000
+        K = 1000  # noqa
         lp.auto_test_vs_ref(seq_knl, ctx, variant(knl),
                 op_count=[K*(n*n*n*2*2 + n*n*2*3 + n**3 * 2*2)/1e9],
                 op_label=["GFlops"],
-- 
GitLab
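
Usage notes — the sketches below are not part of the patch. They are untested
illustrations written against the API this patch introduces; every kernel or
instruction object in them is a fake stand-in that only provides the
attributes the match terminals consult.

The glob terminals (``id:``, ``tag:``, ``writes:``, ``reads:``, ``iname:``)
are compiled to anchored regular expressions via ``fnmatch.translate``,
exactly as ``re_from_glob`` in the patch does. A minimal standalone sketch of
that translation step::

    import re
    from fnmatch import translate

    def re_from_glob(s):
        # anchor the translated glob so the whole name must match
        return re.compile("^" + translate(s.strip()) + "$")

    print(bool(re_from_glob("fetch*").match("fetch_D")))   # True
    print(bool(re_from_glob("fetch*").match("prefetch")))  # False: anchored at the start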
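
``parse_match`` builds a boolean expression tree in which ``not`` binds more
tightly than ``and``, which binds more tightly than ``or``
(``_PREC_NOT > _PREC_AND > _PREC_OR``). A hypothetical session, assuming a
loopy checkout with this patch applied; ``FakeInsn`` stands in for a real
instruction::

    from loopy.context_matching import parse_match

    m = parse_match("id:a* and not tag:b or writes:c")
    print(m)   # ((id:a* and (not tag:b)) or writes:c)

    class FakeInsn(object):
        id = "a_fetch"
        tags = ("b",)

        def write_dependency_names(self):
            return frozenset(["c"])

    # the kernel argument is only consulted by iname: terminals, so None
    # suffices for this illustration
    print(bool(m(None, FakeInsn())))   # True: the writes:c branch matches

This is the same matching used by ``lp.find_instructions``,
``lp.tag_instructions``, ``lp.set_instruction_priority`` and
``lp.add_dependency`` after this patch, with the real kernel and instruction
in place of the stand-ins.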
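
``parse_stack_match`` chains ``StackItemMatchComponent`` and
``StackWildcardMatchComponent`` objects outermost-first, and the resulting
``StackMatch`` is called with the instruction plus a ``(name, tags)`` rule
stack. Another untested sketch along the same lines::

    from loopy.context_matching import parse_stack_match

    sm = parse_stack_match("id:store* > ... > id:rule_b $")

    class FakeInsn(object):
        id = "store_out"
        tags = ()

    # rule_stack entries are (name, tags) pairs, outermost invocation first
    print(bool(sm(None, FakeInsn(), (("rule_a", ()), ("rule_b", ())))))
    # True: insn matches id:store*, "..." skips rule_a, rule_b sits at the bottom

    print(bool(sm(None, FakeInsn(), (("rule_a", ()),))))
    # False: the trailing "$" pins id:rule_b to the innermost stack entry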
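
Transformations now invoke ``within(kernel, insn, stack)`` instead of
``within(stack)``, which is why the ``lambda stack: True`` defaults in
``fusion.py`` and ``affine_map_inames`` were replaced by
``parse_stack_match(None)``. Any hand-rolled ``within`` callable would need
the same signature change; a hedged equivalent::

    # before this patch, a match-everything "within" could be written as
    #     within = lambda stack: True
    # after it, the equivalent three-argument form would be
    within = lambda kernel, insn, stack: True

    # though the patch itself spells this as
    from loopy.context_matching import parse_stack_match
    within = parse_stack_match(None)   # matches everywhere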
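
For callers migrating from the removed id-match syntax, the test changes in
this patch suggest the following translations (the right-hand "or" form for
comma lists is an assumed equivalent, not taken verbatim from the patch)::

    # old id match            ->  new query
    # "fetch*$first"          ->  "id:fetch* and tag:first"
    # "f*$one,g*$one"         ->  "(id:f* and tag:one) or (id:g* and tag:one)"
    #
    # old stack match, innermost first:  "g$two < h$two"
    # new stack match, outermost first:  "... > id:h and tag:two > id:g and tag:two"

Note the reversal: the old ``<`` syntax listed the lowest stack level first,
while the new ``>`` syntax lists the outermost level (the instruction) first.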