diff --git a/loopy/match.py b/loopy/match.py
index 3c047e463939cd67a4878d202a754c0cab48058d..8ab9fd74382d8ce7fe89d51d8a4f823fe7a398b1 100644
--- a/loopy/match.py
+++ b/loopy/match.py
@@ -25,13 +25,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from six.moves import range, intern
+from six import string_types
+from six.moves import range
+from loopy.diagnostic import LoopyError
+import ply.lex as lex
+import ply.yacc as yacc
 
 
 NoneType = type(None)
 
-from pytools.lex import RE
-
 __doc__ = """
 
 .. autofunction:: parse_match
@@ -59,58 +61,6 @@ def re_from_glob(s):
     return re.compile("^"+translate(s.strip())+"$")
 
 
-# {{{ parsing
-
-# {{{ lexer data
-
-_and = intern("and")
-_or = intern("or")
-_not = intern("not")
-_openpar = intern("openpar")
-_closepar = intern("closepar")
-
-_id = intern("_id")
-_tag = intern("_tag")
-_writes = intern("_writes")
-_reads = intern("_reads")
-_iname = intern("_iname")
-
-_whitespace = intern("_whitespace")
-
-# }}}
-
-
-_LEX_TABLE = [
-    (_and, RE(r"and\b")),
-    (_or, RE(r"or\b")),
-    (_not, RE(r"not\b")),
-    (_openpar, RE(r"\(")),
-    (_closepar, RE(r"\)")),
-
-    # TERMINALS
-    (_id, RE(r"id:([\w?*]+)")),
-    (_tag, RE(r"tag:([\w?*]+)")),
-    (_writes, RE(r"writes:([\w?*]+)")),
-    (_reads, RE(r"reads:([\w?*]+)")),
-    (_iname, RE(r"iname:([\w?*]+)")),
-
-    (_whitespace, RE("[ \t]+")),
-    ]
-
-
-_TERMINALS = ([_id, _tag, _writes, _reads, _iname])
-
-# {{{ operator precedence
-
-_PREC_OR = 10
-_PREC_AND = 20
-_PREC_NOT = 30
-
-# }}}
-
-# }}}
-
-
 # {{{ match expression
 
 class MatchExpressionBase(object):
@@ -278,94 +228,117 @@ def parse_match(expr):
     * ``id:yoink and writes:a_temp``
     * ``id:yoink and (not writes:a_temp or tag:input)``
     """
     if not expr:
         return All()
 
-    def parse_terminal(pstate):
-        next_tag = pstate.next_tag()
-        if next_tag is _id:
-            result = Id(pstate.next_match_obj().group(1))
-            pstate.advance()
-            return result
-        elif next_tag is _tag:
-            result = Tagged(pstate.next_match_obj().group(1))
-            pstate.advance()
-            return result
-        elif next_tag is _writes:
-            result = Writes(pstate.next_match_obj().group(1))
-            pstate.advance()
-            return result
-        elif next_tag is _reads:
-            result = Reads(pstate.next_match_obj().group(1))
-            pstate.advance()
-            return result
-        elif next_tag is _iname:
-            result = Iname(pstate.next_match_obj().group(1))
-            pstate.advance()
-            return result
+    if isinstance(expr, MatchExpressionBase):
+        return expr
+
+    # string_types rather than str so that py2 unicode input is accepted;
+    # an explicit raise (unlike assert) still fires under "python -O".
+    if not isinstance(expr, string_types):
+        raise TypeError("match expression must be a string, not '%s'"
+                % type(expr).__name__)
+
+    # property keyword -> match class built from "keyword:NAME"
+    _RESERVED_TO_TYPES = {
+        'id': Id,
+        'iname': Iname,
+        'reads': Reads,
+        'writes': Writes,
+        'tag': Tagged
+    }
+
+    _BINARY_OPS_TO_TYPES = {
+        'and': And,
+        'or': Or
+    }
+
+    # keyword -> ply token type; every other name lexes as 'NAME'
+    reserved = {
+        'id': 'ID',
+        'iname': 'INAME',
+        'reads': 'READS',
+        'writes': 'WRITES',
+        'tag': 'TAG',
+
+        'not': 'NOT',
+        'or': 'OR',
+        'and': 'AND'}
+
+    tokens = ('NAME', 'COLON', 'LPAREN', 'RPAREN',  # noqa: F841
+            ) + tuple(reserved.values())
+
+    # ply lists precedence levels from lowest to highest
+    precedence = (  # noqa: F841
+        ('left', 'OR'),
+        ('left', 'AND'),
+        ('left', 'NOT'),)
+
+    # '?' is a glob wildcard like '*'; the pytools lexer accepted [\w?*]+
+    def t_NAME(t):
+        r'[a-zA-Z_*?][a-zA-Z0-9_*?.]*'
+        t.type = reserved.get(t.value, 'NAME')
+        return t
+
+    t_COLON = r':'  # noqa: F841,N806
+    t_LPAREN = r'\('  # noqa: F841,N806
+    t_RPAREN = r'\)'  # noqa: F841,N806
+    t_ignore = ' \t'  # noqa: F841
+
+    def t_error(t):
+        raise LoopyError("Illegal character '%s'." % t.value[0])
+
+    def p_expr_of_binary_ops(p):
+        '''expression : expression AND expression
+                      | expression OR expression'''
+        # splice in operands of a same-operator child so chains stay flat
+        children = []
+        if type(p[1]) == _BINARY_OPS_TO_TYPES[p[2]]:
+            children.extend(p[1].children)
         else:
-            pstate.expected("terminal")
-
-    def inner_parse(pstate, min_precedence=0):
-        pstate.expect_not_end()
-
-        if pstate.is_next(_not):
-            pstate.advance()
-            left_query = Not(inner_parse(pstate, _PREC_NOT))
-        elif pstate.is_next(_openpar):
-            pstate.advance()
-            left_query = inner_parse(pstate)
-            pstate.expect(_closepar)
-            pstate.advance()
+            children.append(p[1])
+        if type(p[3]) == _BINARY_OPS_TO_TYPES[p[2]]:
+            children.extend(p[3].children)
         else:
-            left_query = parse_terminal(pstate)
+            children.append(p[3])
+        p[0] = _BINARY_OPS_TO_TYPES[p[2]](tuple(children))
 
-        did_something = True
-        while did_something:
-            did_something = False
-            if pstate.is_at_end():
-                return left_query
+    def p_expr_of_unary_ops(p):
+        'expression : NOT expression'
+        p[0] = Not(p[2])
 
-            next_tag = pstate.next_tag()
+    def p_expr_of_parens(p):
+        'expression : LPAREN expression RPAREN'
+        p[0] = p[2]
 
-            if next_tag is _and and _PREC_AND > min_precedence:
-                pstate.advance()
-                left_query = And(
-                        (left_query, inner_parse(pstate, _PREC_AND)))
-                did_something = True
-            elif next_tag is _or and _PREC_OR > min_precedence:
-                pstate.advance()
-                left_query = Or(
-                        (left_query, inner_parse(pstate, _PREC_OR)))
-                did_something = True
+    def p_expr_as_term(p):
+        'expression : term'
+        p[0] = p[1]
 
-        return left_query
+    def p_terminal(p):
+        'term : prop COLON NAME'
+        p[0] = _RESERVED_TO_TYPES[p[1]](p[3])
 
-    if isinstance(expr, MatchExpressionBase):
-        return expr
+    def p_prop(p):
+        '''prop : ID
+                | INAME
+                | TAG
+                | READS
+                | WRITES'''
+        p[0] = p[1]
+
+    def p_error(p):
+        # ply passes None when the input ends before the parse is complete
+        if p is None:
+            raise LoopyError("Syntax error: unexpected end of match "
+                    "expression '%s'." % expr)
+        raise LoopyError("Syntax error at '%s'." % str(p))
+
+    lexer = lex.lex()  # noqa: F841
 
-    from pytools.lex import LexIterator, lex, InvalidTokenError
-    try:
-        pstate = LexIterator(
-            [(tag, s, idx, matchobj)
-             for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr,
-                     match_objects=True)
-             if tag is not _whitespace], expr)
-    except InvalidTokenError as e:
-        from loopy.diagnostic import LoopyError
-        raise LoopyError(
-                "invalid match expression: '{match_expr}' ({err_type}: {err_str})"
-                .format(
-                    match_expr=expr,
-                    err_type=type(e).__name__,
-                    err_str=str(e)))
-
-    if pstate.is_at_end():
-        pstate.raise_parse_error("unexpected end of input")
-
-    result = inner_parse(pstate)
-    if not pstate.is_at_end():
-        pstate.raise_parse_error("leftover input after completed parse")
+    parser = yacc.yacc(debug=False)
+    result = parser.parse(expr)
 
     return result
 
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 89b4f5e639a031d3f2d4d89b470d2ccf5fb4b848..ef445fa5e151d353d56a87db06b736dc6ba460cf 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2986,6 +2986,27 @@ def test_shape_mismatch_check(ctx_factory):
     prg(queue, a=a, b=b)
 
 
+def test_query_lang_parse():
+    import pytest
+    from loopy.diagnostic import LoopyError
+    from loopy.match import (parse_match, Id, And, Writes, Reads, Not, Iname,
+            Tagged, Or)
+    expr1 = 'id:yoink and writes:a_temp and reads:b_temp'
+    expr2 = 'id:yoink and (not iname:i* or tag:input)'
+    expr3 = 'id:foo? or tag:b*r'
+
+    assert parse_match(expr1) == And(
+        (Id('yoink'), Writes('a_temp'), Reads('b_temp')))
+    assert parse_match(expr2) == And((
+        Id('yoink'), Or((Not(Iname('i*')), Tagged('input')))))
+    assert parse_match(expr3) == Or((Id('foo?'), Tagged('b*r')))
+
+    # malformed input must raise rather than be padded with All()
+    for bad_expr in ['id:yoink and', 'not', '()', 'id:yo$nk']:
+        with pytest.raises(LoopyError):
+            parse_match(bad_expr)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])