diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 5dee96e75d36eace37efd1ea1fcaa98cfef7d0ec..bfe56799611dbf3dc0d7fd36cbea8b79c5b233dc 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -27,6 +27,9 @@ from pytools import ImmutableRecord, memoize_method from loopy.diagnostic import LoopyError from loopy.tools import Optional from warnings import warn +import re +import ply.lex as lex +import ply.yacc as yacc # {{{ instructions: base class @@ -1531,4 +1534,127 @@ def _check_and_fix_temp_var_type(temp_var_type, stacklevel=2): # }}} +# {{{ parsing instruction tags + +def parse_instruction_tag(tag): + tokens = ('NAME', 'ASSIGN', 'COMMA', 'INSN', # noqa: F841,N806 + 'LKWTUPLE', 'RKWTUPLE', 'LPAREN', 'RPAREN') + + t_NAME = r'[a-zA-Z_][a-zA-Z0-9_]*' # noqa: F841,N806 + t_INSN = (r' [a-zA-Z0-9_={}\s\[\]\.(),]+' # noqa: F841,N806 + ' ') + t_COMMA = r',' # noqa: F841,N806 + t_ASSIGN = r'=' # noqa: F841,N806 + t_LPAREN = r'\(' # noqa: F841,N806 + t_RPAREN = r'\)' # noqa: F841,N806 + t_LKWTUPLE = r'' # noqa: F841,N806 + t_RKWTUPLE = r'' # noqa: F841,N806 + t_ignore = ' \t' # noqa: F841,N806 + + def t_error(t): + raise RuntimeError("Illegal character '%s'." % t.value[0]) + + def p_kwargs(p): + 'kwargs : kwargs COMMA kwarg' + p[0] = p[1] + (p[3],) + + def p_kwargs_as_kwarg(p): + 'kwargs : kwarg' + p[0] = (p[1],) + + def p_kwarg(p): + 'kwarg : NAME ASSIGN kwarg_val' + p[0] = (p[1], p[3]) + + def p_kwarg_val_as_pickled_obj(p): + 'kwarg_val : INSN' + p[0] = deserialize_instruction( + re.match("(.*)", p[1]).group(1)) + + def p_kwarg_val_as_tuple(p): + '''kwarg_val : LKWTUPLE LPAREN names RPAREN RKWTUPLE + | LKWTUPLE LPAREN names COMMA RPAREN RKWTUPLE + | LKWTUPLE LPAREN NAME COMMA RPAREN RKWTUPLE''' + if isinstance(p[3], tuple): + p[0] = p[3] + else: + p[0] = (p[3], ) + + def p_kwarg_val_as_name(p): + 'kwarg_val : NAME' + p[0] = p[1] + + def p_tuple_defn(p): + 'names : names COMMA NAME' + p[0] = p[1] + (p[3],) + + def p_tuple_of_2_elems(p): + 'names : NAME COMMA NAME' + p[0] = (p[1], p[3]) + + def p_error(p): + raise LoopyError("Syntax error at '%s'." % str(p)) + + lexer = lex.lex() # noqa:F841 + parser = yacc.yacc() + assert isinstance(tag, str) + return parser.parse(tag) + +# }}} + + +def serialize_instruction(insn): + + # {{{ get the options + + insn_options = 'id=%s, ' % insn.id + if insn.depends_on: + insn_options += ("dep="+":".join(insn.depends_on)+", ") + if insn.tags: + insn_options += ("tags="+":".join(insn.tags)+", ") + if insn.within_inames: + insn_options += ("inames="+":".join(insn.within_inames)+", ") + if isinstance(insn, MultiAssignmentBase): + if insn.atomicity: + insn_options += "atomic, " + elif isinstance(insn, BarrierInstruction): + insn_options += ("mem_kind=%s, " % insn.mem_kind) + + insn_options = insn_options[:-2] # remove ', ' + + # }}} + + insn_core = "" + + if isinstance(insn, MultiAssignmentBase): + if insn.assignees: + insn_core += "{assignees} = ".format( + assignees=', '.join(str(assignee) for assignee in + insn.assignees)) + insn_core += str(insn.expression) + elif isinstance(insn, BarrierInstruction): + insn_core += "... {kind}barrier".format( + kind=insn.synchronization_kind[0]) + elif isinstance(insn, NoOpInstruction): + insn_core += "... nop" + else: + raise NotImplementedError() + + return "{core} {{{opts}}}".format(core=insn_core, opts=insn_options) + + +def deserialize_instruction(insn): + assert isinstance(insn, str) + from loopy.kernel.creation import parse_instructions + insns, inames_to_dup, substitutions = parse_instructions( + insn.split('\n'), {}) + if substitutions: + raise LoopyError("Received substitutions on parsing" + " {}".format(insn)) + if len(insns) != 1: + raise LoopyError("Received multiple/no instructions on parsing" + " {}".format(insn)) + #FIXME; Any checks on inames_to_dup? + return insns[0] + # vim: foldmethod=marker diff --git a/loopy/match.py b/loopy/match.py index 3c047e463939cd67a4878d202a754c0cab48058d..1f37b974c8f69f76816f1e5d86d8c5a68ba1b1c9 100644 --- a/loopy/match.py +++ b/loopy/match.py @@ -1,4 +1,4 @@ -"""Matching functionality for instruction ids and subsitution +"""Matching functionality for instruction ids and substitution rule invocations stacks.""" from __future__ import division, absolute_import @@ -25,13 +25,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from six.moves import range, intern +import six +from six.moves import range +from loopy.diagnostic import LoopyError +from loopy.kernel.instruction import (parse_instruction_tag, + InstructionBase) +import ply.lex as lex +import ply.yacc as yacc +import re +from itertools import permutations NoneType = type(None) -from pytools.lex import RE - __doc__ = """ .. autofunction:: parse_match @@ -59,58 +65,6 @@ def re_from_glob(s): return re.compile("^"+translate(s.strip())+"$") -# {{{ parsing - -# {{{ lexer data - -_and = intern("and") -_or = intern("or") -_not = intern("not") -_openpar = intern("openpar") -_closepar = intern("closepar") - -_id = intern("_id") -_tag = intern("_tag") -_writes = intern("_writes") -_reads = intern("_reads") -_iname = intern("_iname") - -_whitespace = intern("_whitespace") - -# }}} - - -_LEX_TABLE = [ - (_and, RE(r"and\b")), - (_or, RE(r"or\b")), - (_not, RE(r"not\b")), - (_openpar, RE(r"\(")), - (_closepar, RE(r"\)")), - - # TERMINALS - (_id, RE(r"id:([\w?*]+)")), - (_tag, RE(r"tag:([\w?*]+)")), - (_writes, RE(r"writes:([\w?*]+)")), - (_reads, RE(r"reads:([\w?*]+)")), - (_iname, RE(r"iname:([\w?*]+)")), - - (_whitespace, RE("[ \t]+")), - ] - - -_TERMINALS = ([_id, _tag, _writes, _reads, _iname]) - -# {{{ operator precedence - -_PREC_OR = 10 -_PREC_AND = 20 -_PREC_NOT = 30 - -# }}} - -# }}} - - # {{{ match expression class MatchExpressionBase(object): @@ -214,7 +168,6 @@ class GlobMatchExpressionBase(MatchExpressionBase): def __init__(self, glob): self.glob = glob - import re from fnmatch import translate self.re = re.compile("^"+translate(glob.strip())+"$") @@ -267,6 +220,94 @@ class Iname(GlobMatchExpressionBase): return any(self.re.match(name) for name in matchable.within_inames) + +class Call(GlobMatchExpressionBase): + def __init__(self, glob, kwargs): + assert isinstance(glob, str) + if not isinstance(kwargs, dict): + kwargs_as_tuple = kwargs[:] + kwargs = dict(kwargs_as_tuple) + if len(kwargs) != len(kwargs_as_tuple): + raise LoopyError("Call '{0}' has repetition in the given " + "attributes.".format(glob)) + self.kwargs = kwargs + self.glob = glob + from fnmatch import translate + self.re = re.compile("^"+translate(glob.strip())) + + def __call__(self, kernel, matchable): + for tag in matchable.tags: + match = re.match(r'^(?P[\w.\*]+)\((?P.+)\)$', tag) + + if not match: + # no match => skip this tag + continue + + if not self.re.match(match.groupdict()['name']): + # 'name' of the call should fnmatch with the tag call name + continue + + tag_kwargs = dict(parse_instruction_tag(match.groupdict()['kwargs'])) + from fnmatch import translate + + # {{{ sanity checks on the provided kwargs + + if not(self.kwargs.keys() <= tag_kwargs.keys()): + raise LoopyError( + "Unknown attributes '{0}' provided to '{1}'.".format( + self.kwargs.keys()-tag_kwargs.keys(), + self.glob)) + # }}} + + did_match = True + for kw, match_arg in six.iteritems(self.kwargs): + tagged_arg = tag_kwargs[kw] + if isinstance(tagged_arg, InstructionBase): + if not match_arg(kernel, tagged_arg): + did_match = False + break + elif isinstance(tagged_arg, tuple): + if not isinstance(match_arg, tuple): + match_arg = (match_arg,) + + if len(match_arg) != len(tagged_arg): + # FIXME: Should we check for matches if the number of + # the names in the mathchee tuple is less than the + # matcher? + did_match = False + break + + tuple_match = False + # Check for matches between all tuple perms + for tagged_perm, match_perm in zip( + permutations(tagged_arg), permutations(match_arg)): + if all(re.match("^"+translate(match_name), tagged_name) + for tagged_name, match_name in zip( + tagged_perm, match_perm)): + tuple_match = True + break + if not tuple_match: + did_match = False + break + elif isinstance(tagged_arg, str): + if not re.match("^"+translate(match_arg), tagged_arg): + did_match = False + break + else: + raise NotImplementedError("Unknown tagged arg {0}.".format( + type(tagged_arg).__name__)) + + if did_match: + return True + + return False + + def __str__(self): + return "{0}({1})".format( + self.glob, + ", ".join("%s=%s" % (key, val) for key, val in + six.iteritems(self.kwargs))) + # }}} @@ -278,94 +319,158 @@ def parse_match(expr): * ``id:yoink and writes:a_temp`` * ``id:yoink and (not writes:a_temp or tag:input)`` """ - if not expr: - return All() - - def parse_terminal(pstate): - next_tag = pstate.next_tag() - if next_tag is _id: - result = Id(pstate.next_match_obj().group(1)) - pstate.advance() - return result - elif next_tag is _tag: - result = Tagged(pstate.next_match_obj().group(1)) - pstate.advance() - return result - elif next_tag is _writes: - result = Writes(pstate.next_match_obj().group(1)) - pstate.advance() - return result - elif next_tag is _reads: - result = Reads(pstate.next_match_obj().group(1)) - pstate.advance() - return result - elif next_tag is _iname: - result = Iname(pstate.next_match_obj().group(1)) - pstate.advance() - return result - else: - pstate.expected("terminal") - - def inner_parse(pstate, min_precedence=0): - pstate.expect_not_end() - - if pstate.is_next(_not): - pstate.advance() - left_query = Not(inner_parse(pstate, _PREC_NOT)) - elif pstate.is_next(_openpar): - pstate.advance() - left_query = inner_parse(pstate) - pstate.expect(_closepar) - pstate.advance() - else: - left_query = parse_terminal(pstate) - - did_something = True - while did_something: - did_something = False - if pstate.is_at_end(): - return left_query - next_tag = pstate.next_tag() - - if next_tag is _and and _PREC_AND > min_precedence: - pstate.advance() - left_query = And( - (left_query, inner_parse(pstate, _PREC_AND))) - did_something = True - elif next_tag is _or and _PREC_OR > min_precedence: - pstate.advance() - left_query = Or( - (left_query, inner_parse(pstate, _PREC_OR))) - did_something = True - - return left_query + # None -> '' + expr = expr if expr else '' if isinstance(expr, MatchExpressionBase): return expr - from pytools.lex import LexIterator, lex, InvalidTokenError - try: - pstate = LexIterator( - [(tag, s, idx, matchobj) - for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr, - match_objects=True) - if tag is not _whitespace], expr) - except InvalidTokenError as e: - from loopy.diagnostic import LoopyError - raise LoopyError( - "invalid match expression: '{match_expr}' ({err_type}: {err_str})" - .format( - match_expr=expr, - err_type=type(e).__name__, - err_str=str(e))) - - if pstate.is_at_end(): - pstate.raise_parse_error("unexpected end of input") - - result = inner_parse(pstate) - if not pstate.is_at_end(): - pstate.raise_parse_error("leftover input after completed parse") + _RESERVED_TO_TYPES = { + 'id': Id, + 'iname': Iname, + 'reads': Reads, + 'writes': Writes, + 'tag': Tagged + } + + _BINARY_OPS_TO_TYPES = { + 'and': And, + 'or': Or + } + + reserved = { + 'id': 'ID', + 'iname': 'INAME', + 'reads': 'READS', + 'writes': 'WRITES', + 'tag': 'TAG', + + 'not': 'NOT', + 'or': 'OR', + 'and': 'AND'} + + tokens = ('NAME', 'COLON', 'LPAREN', 'RPAREN', 'COMMA', 'ASSIGN',) + tuple(reserved.values()) # noqa + + precedence = ( # noqa + ('left', 'OR'), + ('left', 'AND'), + ('left', 'NOT'),) + + def t_NAME(t): + r'[a-zA-Z_*][a-zA-Z0-9_*.]*' + t.type = reserved.get(t.value, 'NAME') + return t + + t_COLON = r':' # noqa + t_LPAREN = r'\(' # noqa + t_RPAREN = r'\)' # noqa + t_COMMA = r',' # noqa + t_ASSIGN = r'=' # noqa + t_ignore = ' \t' # noqa + + def t_error(t): + raise RuntimeError("Illegal character '%s'." % t.value[0]) + + def p_expr_of_binary_ops(p): + '''expression : expression AND expression + | expression OR expression''' + children = [] + if type(p[1]) == _BINARY_OPS_TO_TYPES[p[2]]: + children.extend(p[1].children) + else: + children.append(p[1]) + if type(p[3]) == _BINARY_OPS_TO_TYPES[p[2]]: + children.extend(p[3].children) + else: + children.append(p[3]) + p[0] = _BINARY_OPS_TO_TYPES[p[2]](tuple(children)) + + def p_expr_of_unary_ops(p): + 'expression : NOT expression' + p[0] = Not(p[2]) + + def p_parens(p): + '''expression : LPAREN expression RPAREN + parened_name : LPAREN parened_name RPAREN + tuple_of_names : LPAREN tuple_of_names RPAREN''' + p[0] = p[2] + + def p_expr_as_terminal(p): + 'expression : term' + p[0] = p[1] + + def p_terminal_as_property(p): + 'term : prop COLON NAME' + p[0] = _RESERVED_TO_TYPES[p[1]](p[3]) + + def p_prop(p): + '''prop : ID + | INAME + | TAG + | READS + | WRITES''' + p[0] = p[1] + + def p_terminal_as_call(p): + '''term : NAME LPAREN kwargs RPAREN + | NAME LPAREN kwargs COMMA RPAREN''' + p[0] = Call(p[1], p[3]) + + def p_tuple_contents(p): + '''kwargs : kwargs COMMA kwarg + names : names COMMA parened_name''' + p[0] = p[1] + (p[3],) + + def p_tuple_as_element(p): + 'kwargs : kwarg' + p[0] = (p[1],) + + def p_names_tuple_of_len_2(p): + 'names : parened_name COMMA parened_name' + p[0] = (p[1], p[3]) + + def p_kwarg(p): + 'kwarg : NAME ASSIGN kwarg_val' + if not re.match(r'^[a-zA-Z_][a-zA-Z_0-9]*$', p[1]): + raise LoopyError('Invalid Kwarg attribute {0}. Can only contain ' + 'alphanumerics and underscore.'.format(p[1])) + p[0] = (p[1], p[3]) + + def p_kwarg_val(p): + '''kwarg_val : expression + | tuple_of_names + | parened_name''' + p[0] = p[1] + + def p_parened_name(p): + 'parened_name : NAME' + p[0] = p[1] + + def p_tuple_of_names(p): + '''tuple_of_names : LPAREN names RPAREN + | LPAREN names COMMA RPAREN''' + p[0] = p[2] + + def p_tuple_of_name_with_single_elem(p): + 'tuple_of_names : LPAREN parened_name COMMA RPAREN' + p[0] = (p[2], ) + + def p_empty_terminal(p): + 'term : empty' + p[0] = All() + + def p_empty(p): + 'empty :' + pass + + def p_error(p): + raise LoopyError("Syntax error at '%s'." % str(p)) + + lexer = lex.lex() # noqa + + parser = yacc.yacc(debug=False) + result = parser.parse(expr) return result diff --git a/loopy/preprocess.py b/loopy/preprocess.py index 2afcd3db4331d57e1e61c48ba521ebaa296ddbb2..7e093d04ccbbfff6bad60e610ede4248e258a154 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -1032,6 +1032,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, init_id = insn_id_gen( "%s_%s_init" % (insn.id, "_".join(expr.inames))) + from loopy.kernel.instruction import serialize_instruction + init_insn = make_assignment( id=init_id, assignees=acc_vars, @@ -1039,6 +1041,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True, within_inames_is_final=insn.within_inames_is_final, depends_on=init_insn_depends_on, expression=expr.operation.neutral_element(*arg_dtypes), + tags=frozenset([ + "loopy.reduction.init(insn={0}" + ", inames=({1}))".format( + serialize_instruction(insn), + ','.join(expr.inames)+',')]), predicates=insn.predicates,) generated_insns.append(init_insn) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index f5cf07b0e1d62212ce36edb48f47eb7de7d31451..3d2c602d8434d6987b44cf9d49780de84cf5ce3d 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -204,7 +204,7 @@ class StringifyMapper(StringifyMapperBase): def map_reduction(self, expr, prec): from pymbolic.mapper.stringifier import PREC_NONE - return "%sreduce(%s, [%s], %s)" % ( + return "%sreduce(%s, (%s), %s)" % ( "simul_" if expr.allow_simultaneous else "", expr.operation, ", ".join(expr.inames), self.rec(expr.expr, PREC_NONE)) diff --git a/test/test_loopy.py b/test/test_loopy.py index 89b4f5e639a031d3f2d4d89b470d2ccf5fb4b848..fd896435cd077ed9cea21df407ff64b04667cc83 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2986,6 +2986,23 @@ def test_shape_mismatch_check(ctx_factory): prg(queue, a=a, b=b) +def test_transform_addressing(): + from loopy.match import parse_match + + knl = lp.make_kernel( + "{[i]: 0<=i<10}", + """ + a = sum(i, b[i]) {id=insn} + """) + + knl = lp.realize_reduction(knl) + + within = parse_match("loopy.reduction.init(insn=id:insn, inames=(i,))") + + assert len([print(insn) + for insn in knl.instructions if within(knl, insn)]) == 1 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])