diff --git a/doc/ref_kernel.rst b/doc/ref_kernel.rst index 253ccef3cc91b255f8c14b821257a8a80936daf6..d151a2128754373559ea8a45730609782adb33d7 100644 --- a/doc/ref_kernel.rst +++ b/doc/ref_kernel.rst @@ -183,7 +183,7 @@ These are usually key-value pairs. The following attributes are recognized: * ``id=value`` sets the instruction's identifier to ``value``. ``value`` must be unique within the kernel. This identifier is used to refer to the instruction after it has been created, such as from ``dep`` attributes - (see below) or from :mod:`context matches <loopy.context_matching>`. + (see below) or from :mod:`context matches <loopy.match>`. * ``id_prefix=value`` also sets the instruction's identifier, however uniqueness is ensured by loopy itself, by appending further components diff --git a/doc/ref_transform.rst b/doc/ref_transform.rst index 4c3c24873945dc0d87ec414ff7870fafd0a5bda1..f92cfbf67c0d2af723db8f6fa0ee67ebb14c08fe 100644 --- a/doc/ref_transform.rst +++ b/doc/ref_transform.rst @@ -124,12 +124,7 @@ Matching contexts TODO: Matching instruction tags -.. automodule:: loopy.context_matching - -.. autofunction:: parse_match - -.. autofunction:: parse_stack_match - +.. automodule:: loopy.match .. vim: tw=75:spell diff --git a/loopy/context_matching.py b/loopy/match.py similarity index 89% rename from loopy/context_matching.py rename to loopy/match.py index 45f9a4d74b353e5821d2a2ed3e410c44e3187eb8..d285d464351e39e16ecdf96280fbebb4f6246f8e 100644 --- a/loopy/context_matching.py +++ b/loopy/match.py @@ -32,6 +32,26 @@ NoneType = type(None) from pytools.lex import RE +__doc__ = """ +.. autofunction:: parse_match + +.. autofunction:: parse_stack_match + +Match expressions +^^^^^^^^^^^^^^^^^ + +.. autoclass:: MatchExpressionBase +.. autoclass:: All +.. autoclass:: And +.. autoclass:: Or +.. autoclass:: Not +.. autoclass:: Id +.. autoclass:: Tagged +.. autoclass:: Writes +.. autoclass:: Reads +.. autoclass:: Iname +""" + def re_from_glob(s): import re @@ -97,8 +117,17 @@ class MatchExpressionBase(object): def __ne__(self, other): return not self.__eq__(other) + def __and__(self, other): + return And((self, other)) + + def __or__(self, other): + return Or((self, other)) -class AllMatchExpression(MatchExpressionBase): + def __inv__(self): + return Not(self) + + +class All(MatchExpressionBase): def __call__(self, kernel, matchable): return True @@ -109,7 +138,7 @@ class AllMatchExpression(MatchExpressionBase): return (type(self) == type(other)) -class AndMatchExpression(MatchExpressionBase): +class And(MatchExpressionBase): def __init__(self, children): self.children = children @@ -128,7 +157,7 @@ class AndMatchExpression(MatchExpressionBase): and self.children == other.children) -class OrMatchExpression(MatchExpressionBase): +class Or(MatchExpressionBase): def __init__(self, children): self.children = children @@ -147,7 +176,7 @@ class OrMatchExpression(MatchExpressionBase): and self.children == other.children) -class NotMatchExpression(MatchExpressionBase): +class Not(MatchExpressionBase): def __init__(self, child): self.child = child @@ -176,7 +205,6 @@ class GlobMatchExpressionBase(MatchExpressionBase): def __str__(self): descr = type(self).__name__ - descr = descr[:descr.find("Match")] return descr.lower() + ":" + self.glob def update_persistent_hash(self, key_hash, key_builder): @@ -188,12 +216,12 @@ class GlobMatchExpressionBase(MatchExpressionBase): and self.glob == other.glob) -class IdMatchExpression(GlobMatchExpressionBase): +class Id(GlobMatchExpressionBase): def __call__(self, kernel, matchable): return self.re.match(matchable.id) -class TagMatchExpression(GlobMatchExpressionBase): +class Tagged(GlobMatchExpressionBase): def __call__(self, kernel, matchable): if matchable.tags: return any(self.re.match(tag) for tag in matchable.tags) @@ -201,19 +229,19 @@ class TagMatchExpression(GlobMatchExpressionBase): return False -class WritesMatchExpression(GlobMatchExpressionBase): +class Writes(GlobMatchExpressionBase): def __call__(self, kernel, matchable): return any(self.re.match(name) for name in matchable.write_dependency_names()) -class ReadsMatchExpression(GlobMatchExpressionBase): +class Reads(GlobMatchExpressionBase): def __call__(self, kernel, matchable): return any(self.re.match(name) for name in matchable.read_dependency_names()) -class InameMatchExpression(GlobMatchExpressionBase): +class Iname(GlobMatchExpressionBase): def __call__(self, kernel, matchable): return any(self.re.match(name) for name in matchable.inames(kernel)) @@ -223,35 +251,35 @@ class InameMatchExpression(GlobMatchExpressionBase): # {{{ parser -def parse_match(expr_str): +def parse_match(expr): """Syntax examples:: * ``id:yoink and writes:a_temp`` * ``id:yoink and (not writes:a_temp or tagged:input)`` """ - if not expr_str: - return AllMatchExpression() + if not expr: + return All() def parse_terminal(pstate): next_tag = pstate.next_tag() if next_tag is _id: - result = IdMatchExpression(pstate.next_match_obj().group(1)) + result = Id(pstate.next_match_obj().group(1)) pstate.advance() return result elif next_tag is _tag: - result = TagMatchExpression(pstate.next_match_obj().group(1)) + result = Tagged(pstate.next_match_obj().group(1)) pstate.advance() return result elif next_tag is _writes: - result = WritesMatchExpression(pstate.next_match_obj().group(1)) + result = Writes(pstate.next_match_obj().group(1)) pstate.advance() return result elif next_tag is _reads: - result = ReadsMatchExpression(pstate.next_match_obj().group(1)) + result = Reads(pstate.next_match_obj().group(1)) pstate.advance() return result elif next_tag is _iname: - result = InameMatchExpression(pstate.next_match_obj().group(1)) + result = Iname(pstate.next_match_obj().group(1)) pstate.advance() return result else: @@ -262,7 +290,7 @@ def parse_match(expr_str): if pstate.is_next(_not): pstate.advance() - left_query = NotMatchExpression(inner_parse(pstate, _PREC_NOT)) + left_query = Not(inner_parse(pstate, _PREC_NOT)) elif pstate.is_next(_openpar): pstate.advance() left_query = inner_parse(pstate) @@ -281,30 +309,33 @@ def parse_match(expr_str): if next_tag is _and and _PREC_AND > min_precedence: pstate.advance() - left_query = AndMatchExpression( + left_query = And( (left_query, inner_parse(pstate, _PREC_AND))) did_something = True elif next_tag is _or and _PREC_OR > min_precedence: pstate.advance() - left_query = OrMatchExpression( + left_query = Or( (left_query, inner_parse(pstate, _PREC_OR))) did_something = True return left_query + if isinstance(expr, MatchExpressionBase): + return expr + from pytools.lex import LexIterator, lex, InvalidTokenError try: pstate = LexIterator( [(tag, s, idx, matchobj) - for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr_str, + for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr, match_objects=True) - if tag is not _whitespace], expr_str) + if tag is not _whitespace], expr) except InvalidTokenError as e: from loopy.diagnostic import LoopyError raise LoopyError( "invalid match expression: '{match_expr}' ({err_type}: {err_str})" .format( - match_expr=expr_str, + match_expr=expr, err_type=type(e).__name__, err_str=str(e))) diff --git a/loopy/transform/buffer.py b/loopy/transform/buffer.py index 677de78eaa2944956e2d65209dc5716af5bb091a..002d5986a6f81a68a1c060937a2ebb7ab4821157 100644 --- a/loopy/transform/buffer.py +++ b/loopy/transform/buffer.py @@ -165,7 +165,7 @@ def buffer_array(kernel, var_name, buffer_inames, init_expression=None, buffer_inames = list(buffer_inames) buffer_inames_set = frozenset(buffer_inames) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) if var_name in kernel.arg_dict: diff --git a/loopy/transform/fusion.py b/loopy/transform/fusion.py index e44f8abe227d451e8e940708530f6c20566685e8..a56c2903a3cf75938357119714a6745dc0d65ade 100644 --- a/loopy/transform/fusion.py +++ b/loopy/transform/fusion.py @@ -37,7 +37,7 @@ def _apply_renames_in_exprs(kernel, var_renames): SubstitutionRuleMappingContext, RuleAwareSubstitutionMapper) from pymbolic.mapper.substitutor import make_subst_func - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match srmc = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 4b575d2e3970c4f1ffcd66ce6021ba9940789fe5..32e718d51053000a6cc81a9df6e8b492a218d7f4 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -147,7 +147,7 @@ def _split_iname_backend(kernel, split_iname, within=None): """ :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. """ existing_tag = kernel.iname_to_tag.get(split_iname) @@ -256,7 +256,7 @@ def _split_iname_backend(kernel, split_iname, applied_iname_rewrites=applied_iname_rewrites, loop_priority=new_loop_priority) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) rule_mapping_context = SubstitutionRuleMappingContext( @@ -301,7 +301,7 @@ def split_iname(kernel, split_iname, inner_length, :arg inner_tag: The iname tag (see :ref:`iname-tags`) to apply to *inner_iname*. :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. """ def make_new_loop_index(inner, outer): return inner + outer*inner_length @@ -330,7 +330,7 @@ def chunk_iname(kernel, split_iname, num_chunks, fixed length *num_chunks*. :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. .. versionadded:: 2016.2 """ @@ -457,7 +457,7 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): """ :arg inames: fastest varying last :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. """ # now fastest varying first @@ -548,7 +548,7 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict] )) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) from pymbolic.mapper.substitutor import make_subst_func @@ -713,7 +713,7 @@ def duplicate_inames(knl, inames, within, new_inames=None, suffix=None, tags={}): """ :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. """ # {{{ normalize arguments, find unique new_inames @@ -724,7 +724,7 @@ def duplicate_inames(knl, inames, within, new_inames=None, suffix=None, if isinstance(new_inames, str): new_inames = [iname.strip() for iname in new_inames.split(",")] - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) if new_inames is None: @@ -802,7 +802,7 @@ def duplicate_inames(knl, inames, within, new_inames=None, suffix=None, def rename_iname(knl, old_iname, new_iname, existing_ok=False, within=None): """ :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. :arg existing_ok: execute even if *new_iname* already exists """ @@ -852,7 +852,7 @@ def rename_iname(knl, old_iname, new_iname, existing_ok=False, within=None): from pymbolic import var subst_dict = {old_iname: var(new_iname)} - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) from pymbolic.mapper.substitutor import make_subst_func @@ -972,7 +972,7 @@ def link_inames(knl, inames, new_iname, within=None, tag=None): from pymbolic import var subst_dict = dict((iname, var(new_iname)) for iname in inames) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) from pymbolic.mapper.substitutor import make_subst_func @@ -1101,7 +1101,7 @@ def _split_reduction(kernel, inames, direction, within=None): inames = inames.split(",") inames = set(inames) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) rule_mapping_context = SubstitutionRuleMappingContext( @@ -1232,7 +1232,7 @@ def affine_map_inames(kernel, old_inames, new_inames, equations): var_name_gen = kernel.get_var_name_generator() from pymbolic.mapper.substitutor import make_subst_func - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match rule_mapping_context = SubstitutionRuleMappingContext( kernel.substitutions, var_name_gen) @@ -1349,7 +1349,7 @@ def find_unused_axis_tag(kernel, kind, insn_match=None): axis. :arg insn_match: An instruction match as understood by - :func:`loopy.context_matching.parse_match`. + :func:`loopy.match.parse_match`. :arg kind: may be "l" or "g", or the corresponding tag class name :returns: an :class:`GroupIndexTag` or :class:`LocalIndexTag` @@ -1371,7 +1371,7 @@ def find_unused_axis_tag(kernel, kind, insn_match=None): if not found: raise LoopyError("invlaid tag kind: %s" % kind) - from loopy.context_matching import parse_match + from loopy.match import parse_match match = parse_match(insn_match) insns = [insn for insn in kernel.instructions if match(kernel, insn)] @@ -1472,14 +1472,14 @@ def make_reduction_inames_unique(kernel, inames=None, within=None): """ :arg inames: if not *None*, only apply to these inames :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. .. versionadded:: 2016.2 """ name_gen = kernel.get_var_name_generator() - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) # {{{ change kernel diff --git a/loopy/transform/instruction.py b/loopy/transform/instruction.py index eefab909f25672860226d402898c442f5fc6e2dd..25634c91975e686fed04223472ee9fbe2b416a4d 100644 --- a/loopy/transform/instruction.py +++ b/loopy/transform/instruction.py @@ -30,13 +30,13 @@ from loopy.diagnostic import LoopyError # {{{ instruction processing def find_instructions(kernel, insn_match): - from loopy.context_matching import parse_match + from loopy.match import parse_match match = parse_match(insn_match) return [insn for insn in kernel.instructions if match(kernel, insn)] def map_instructions(kernel, insn_match, f): - from loopy.context_matching import parse_match + from loopy.match import parse_match match = parse_match(insn_match) new_insns = [] @@ -54,7 +54,7 @@ def set_instruction_priority(kernel, insn_match, priority): """Set the priority of instructions matching *insn_match* to *priority*. *insn_match* may be any instruction id match understood by - :func:`loopy.context_matching.parse_match`. + :func:`loopy.match.parse_match`. """ def set_prio(insn): @@ -68,7 +68,7 @@ def add_dependency(kernel, insn_match, dependency): by *insn_match*. *insn_match* may be any instruction id match understood by - :func:`loopy.context_matching.parse_match`. + :func:`loopy.match.parse_match`. """ if dependency not in kernel.id_to_insn: @@ -130,7 +130,7 @@ def remove_instructions(kernel, insn_ids): # {{{ tag_instructions def tag_instructions(kernel, new_tag, within=None): - from loopy.context_matching import parse_match + from loopy.match import parse_match within = parse_match(within) new_insns = [] diff --git a/loopy/transform/parameter.py b/loopy/transform/parameter.py index f7600b212cbf4db6b58c91bb3603f5a310c6b2a6..fc5dad91dd73245c328f06e2c452b1d3d3a1da2b 100644 --- a/loopy/transform/parameter.py +++ b/loopy/transform/parameter.py @@ -116,7 +116,7 @@ def _fix_parameter(kernel, name, value): for tv in six.itervalues(kernel.temporary_variables): new_temp_vars[tv.name] = tv.map_exprs(map_expr) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(None) rule_mapping_context = SubstitutionRuleMappingContext( diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 2a75e9deaedd8da61f3af1321a88137413f39426..45608ea8e51710e89ac0a024a646b69ea3668a07 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -291,7 +291,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, names/indices to be used as storage axes. May also equivalently be a comma-separated string. :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. :arg temporary_name: The temporary variable name to use for storing the precomputed data. If it does not exist, it will be created. If it does exist, its properties @@ -375,7 +375,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, raise ValueError("not all uses in subst_use agree " "on rule name and tag") - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) from loopy.kernel.data import parse_tag @@ -724,7 +724,7 @@ def precompute(kernel, subst_use, sweep_inames=[], within=None, rule_mapping_context = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match expr_subst_map = RuleAwareSubstitutionMapper( rule_mapping_context, make_subst_func(storage_axis_subst_dict), diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index e599c902227faf8d1292ece2307d097bc8fd7c19..63e70f8e386bfc6ecdddb4b8661e9f975206725e 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -294,7 +294,7 @@ def assignment_to_subst(kernel, lhs_name, extra_arguments=(), within=None, rule. :arg within: a stack match as understood by - :func:`loopy.context_matching.parse_stack_match`. + :func:`loopy.match.parse_stack_match`. :arg force_retain_argument: If True and if *lhs_name* is an argument, it is kept even if it is no longer referenced. @@ -372,7 +372,7 @@ def assignment_to_subst(kernel, lhs_name, extra_arguments=(), within=None, raise LoopyError("no assignments to variable '%s' found" % lhs_name) - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match within = parse_stack_match(within) rule_mapping_context = SubstitutionRuleMappingContext( @@ -463,7 +463,7 @@ def expand_subst(kernel, within=None): logger.debug("%s: expand subst" % kernel.name) from loopy.symbolic import RuleAwareSubstitutionRuleExpander - from loopy.context_matching import parse_stack_match + from loopy.match import parse_stack_match rule_mapping_context = SubstitutionRuleMappingContext( kernel.substitutions, kernel.get_var_name_generator()) submap = RuleAwareSubstitutionRuleExpander( @@ -483,7 +483,7 @@ def find_rules_matching(knl, pattern): :pattern: A shell-style glob pattern. """ - from loopy.context_matching import re_from_glob + from loopy.match import re_from_glob pattern = re_from_glob(pattern) return [r for r in knl.substitutions if pattern.match(r)]