diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py
index 5dee96e75d36eace37efd1ea1fcaa98cfef7d0ec..bfe56799611dbf3dc0d7fd36cbea8b79c5b233dc 100644
--- a/loopy/kernel/instruction.py
+++ b/loopy/kernel/instruction.py
@@ -27,6 +27,9 @@ from pytools import ImmutableRecord, memoize_method
from loopy.diagnostic import LoopyError
from loopy.tools import Optional
from warnings import warn
+import re
+import ply.lex as lex
+import ply.yacc as yacc
# {{{ instructions: base class
@@ -1531,4 +1534,127 @@ def _check_and_fix_temp_var_type(temp_var_type, stacklevel=2):
# }}}
+# {{{ parsing instruction tags
+
+def parse_instruction_tag(tag):
+ tokens = ('NAME', 'ASSIGN', 'COMMA', 'INSN', # noqa: F841,N806
+ 'LKWTUPLE', 'RKWTUPLE', 'LPAREN', 'RPAREN')
+
+ t_NAME = r'[a-zA-Z_][a-zA-Z0-9_]*' # noqa: F841,N806
+    t_INSN = (r'<insn>[a-zA-Z0-9_:={}\s\[\]\.(),]+'  # noqa: F841,N806
+              '</insn>')
+ t_COMMA = r',' # noqa: F841,N806
+ t_ASSIGN = r'=' # noqa: F841,N806
+ t_LPAREN = r'\(' # noqa: F841,N806
+ t_RPAREN = r'\)' # noqa: F841,N806
+    t_LKWTUPLE = r'<tuple>'  # noqa: F841,N806
+    t_RKWTUPLE = r'</tuple>'  # noqa: F841,N806
+ t_ignore = ' \t' # noqa: F841,N806
+
+ def t_error(t):
+ raise RuntimeError("Illegal character '%s'." % t.value[0])
+
+ def p_kwargs(p):
+ 'kwargs : kwargs COMMA kwarg'
+ p[0] = p[1] + (p[3],)
+
+ def p_kwargs_as_kwarg(p):
+ 'kwargs : kwarg'
+ p[0] = (p[1],)
+
+ def p_kwarg(p):
+ 'kwarg : NAME ASSIGN kwarg_val'
+ p[0] = (p[1], p[3])
+
+ def p_kwarg_val_as_pickled_obj(p):
+ 'kwarg_val : INSN'
+ p[0] = deserialize_instruction(
+            re.match("<insn>(.*)</insn>", p[1]).group(1))
+
+ def p_kwarg_val_as_tuple(p):
+ '''kwarg_val : LKWTUPLE LPAREN names RPAREN RKWTUPLE
+ | LKWTUPLE LPAREN names COMMA RPAREN RKWTUPLE
+ | LKWTUPLE LPAREN NAME COMMA RPAREN RKWTUPLE'''
+ if isinstance(p[3], tuple):
+ p[0] = p[3]
+ else:
+ p[0] = (p[3], )
+
+ def p_kwarg_val_as_name(p):
+ 'kwarg_val : NAME'
+ p[0] = p[1]
+
+ def p_tuple_defn(p):
+ 'names : names COMMA NAME'
+ p[0] = p[1] + (p[3],)
+
+ def p_tuple_of_2_elems(p):
+ 'names : NAME COMMA NAME'
+ p[0] = (p[1], p[3])
+
+ def p_error(p):
+ raise LoopyError("Syntax error at '%s'." % str(p))
+
+ lexer = lex.lex() # noqa:F841
+ parser = yacc.yacc()
+ assert isinstance(tag, str)
+ return parser.parse(tag)
+
+# }}}
+
+
+def serialize_instruction(insn):
+
+ # {{{ get the options
+
+ insn_options = 'id=%s, ' % insn.id
+ if insn.depends_on:
+ insn_options += ("dep="+":".join(insn.depends_on)+", ")
+ if insn.tags:
+ insn_options += ("tags="+":".join(insn.tags)+", ")
+ if insn.within_inames:
+ insn_options += ("inames="+":".join(insn.within_inames)+", ")
+ if isinstance(insn, MultiAssignmentBase):
+ if insn.atomicity:
+ insn_options += "atomic, "
+ elif isinstance(insn, BarrierInstruction):
+ insn_options += ("mem_kind=%s, " % insn.mem_kind)
+
+ insn_options = insn_options[:-2] # remove ', '
+
+ # }}}
+
+ insn_core = ""
+
+ if isinstance(insn, MultiAssignmentBase):
+ if insn.assignees:
+ insn_core += "{assignees} = ".format(
+ assignees=', '.join(str(assignee) for assignee in
+ insn.assignees))
+ insn_core += str(insn.expression)
+ elif isinstance(insn, BarrierInstruction):
+ insn_core += "... {kind}barrier".format(
+ kind=insn.synchronization_kind[0])
+ elif isinstance(insn, NoOpInstruction):
+ insn_core += "... nop"
+ else:
+ raise NotImplementedError()
+
+ return "{core} {{{opts}}}".format(core=insn_core, opts=insn_options)
+
+
+def deserialize_instruction(insn):
+ assert isinstance(insn, str)
+ from loopy.kernel.creation import parse_instructions
+ insns, inames_to_dup, substitutions = parse_instructions(
+ insn.split('\n'), {})
+ if substitutions:
+ raise LoopyError("Received substitutions on parsing"
+ " {}".format(insn))
+ if len(insns) != 1:
+ raise LoopyError("Received multiple/no instructions on parsing"
+ " {}".format(insn))
+    # FIXME: Any checks on inames_to_dup?
+ return insns[0]
+
# vim: foldmethod=marker
diff --git a/loopy/match.py b/loopy/match.py
index 3c047e463939cd67a4878d202a754c0cab48058d..1f37b974c8f69f76816f1e5d86d8c5a68ba1b1c9 100644
--- a/loopy/match.py
+++ b/loopy/match.py
@@ -1,4 +1,4 @@
-"""Matching functionality for instruction ids and subsitution
+"""Matching functionality for instruction ids and substitution
rule invocations stacks."""
from __future__ import division, absolute_import
@@ -25,13 +25,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
-from six.moves import range, intern
+import six
+from six.moves import range
+from loopy.diagnostic import LoopyError
+from loopy.kernel.instruction import (parse_instruction_tag,
+ InstructionBase)
+import ply.lex as lex
+import ply.yacc as yacc
+import re
+from itertools import permutations
NoneType = type(None)
-from pytools.lex import RE
-
__doc__ = """
.. autofunction:: parse_match
@@ -59,58 +65,6 @@ def re_from_glob(s):
return re.compile("^"+translate(s.strip())+"$")
-# {{{ parsing
-
-# {{{ lexer data
-
-_and = intern("and")
-_or = intern("or")
-_not = intern("not")
-_openpar = intern("openpar")
-_closepar = intern("closepar")
-
-_id = intern("_id")
-_tag = intern("_tag")
-_writes = intern("_writes")
-_reads = intern("_reads")
-_iname = intern("_iname")
-
-_whitespace = intern("_whitespace")
-
-# }}}
-
-
-_LEX_TABLE = [
- (_and, RE(r"and\b")),
- (_or, RE(r"or\b")),
- (_not, RE(r"not\b")),
- (_openpar, RE(r"\(")),
- (_closepar, RE(r"\)")),
-
- # TERMINALS
- (_id, RE(r"id:([\w?*]+)")),
- (_tag, RE(r"tag:([\w?*]+)")),
- (_writes, RE(r"writes:([\w?*]+)")),
- (_reads, RE(r"reads:([\w?*]+)")),
- (_iname, RE(r"iname:([\w?*]+)")),
-
- (_whitespace, RE("[ \t]+")),
- ]
-
-
-_TERMINALS = ([_id, _tag, _writes, _reads, _iname])
-
-# {{{ operator precedence
-
-_PREC_OR = 10
-_PREC_AND = 20
-_PREC_NOT = 30
-
-# }}}
-
-# }}}
-
-
# {{{ match expression
class MatchExpressionBase(object):
@@ -214,7 +168,6 @@ class GlobMatchExpressionBase(MatchExpressionBase):
def __init__(self, glob):
self.glob = glob
- import re
from fnmatch import translate
self.re = re.compile("^"+translate(glob.strip())+"$")
@@ -267,6 +220,94 @@ class Iname(GlobMatchExpressionBase):
return any(self.re.match(name)
for name in matchable.within_inames)
+
+class Call(GlobMatchExpressionBase):
+ def __init__(self, glob, kwargs):
+ assert isinstance(glob, str)
+ if not isinstance(kwargs, dict):
+ kwargs_as_tuple = kwargs[:]
+ kwargs = dict(kwargs_as_tuple)
+ if len(kwargs) != len(kwargs_as_tuple):
+ raise LoopyError("Call '{0}' has repetition in the given "
+ "attributes.".format(glob))
+ self.kwargs = kwargs
+ self.glob = glob
+ from fnmatch import translate
+ self.re = re.compile("^"+translate(glob.strip()))
+
+ def __call__(self, kernel, matchable):
+ for tag in matchable.tags:
+            match = re.match(r'^(?P<name>[\w.\*]+)\((?P<kwargs>.+)\)$', tag)
+
+ if not match:
+ # no match => skip this tag
+ continue
+
+ if not self.re.match(match.groupdict()['name']):
+ # 'name' of the call should fnmatch with the tag call name
+ continue
+
+ tag_kwargs = dict(parse_instruction_tag(match.groupdict()['kwargs']))
+ from fnmatch import translate
+
+ # {{{ sanity checks on the provided kwargs
+
+ if not(self.kwargs.keys() <= tag_kwargs.keys()):
+ raise LoopyError(
+ "Unknown attributes '{0}' provided to '{1}'.".format(
+ self.kwargs.keys()-tag_kwargs.keys(),
+ self.glob))
+ # }}}
+
+ did_match = True
+ for kw, match_arg in six.iteritems(self.kwargs):
+ tagged_arg = tag_kwargs[kw]
+ if isinstance(tagged_arg, InstructionBase):
+ if not match_arg(kernel, tagged_arg):
+ did_match = False
+ break
+ elif isinstance(tagged_arg, tuple):
+ if not isinstance(match_arg, tuple):
+ match_arg = (match_arg,)
+
+ if len(match_arg) != len(tagged_arg):
+ # FIXME: Should we check for matches if the number of
+ # the names in the mathchee tuple is less than the
+ # matcher?
+ did_match = False
+ break
+
+ tuple_match = False
+ # Check for matches between all tuple perms
+                for match_perm in permutations(match_arg):
+                    if all(re.match("^"+translate(match_name),
+                                    tagged_name)
+                           for tagged_name, match_name in zip(
+                               tagged_arg, match_perm)):
+ tuple_match = True
+ break
+ if not tuple_match:
+ did_match = False
+ break
+ elif isinstance(tagged_arg, str):
+ if not re.match("^"+translate(match_arg), tagged_arg):
+ did_match = False
+ break
+ else:
+ raise NotImplementedError("Unknown tagged arg {0}.".format(
+ type(tagged_arg).__name__))
+
+ if did_match:
+ return True
+
+ return False
+
+ def __str__(self):
+ return "{0}({1})".format(
+ self.glob,
+ ", ".join("%s=%s" % (key, val) for key, val in
+ six.iteritems(self.kwargs)))
+
# }}}
@@ -278,94 +319,158 @@ def parse_match(expr):
* ``id:yoink and writes:a_temp``
* ``id:yoink and (not writes:a_temp or tag:input)``
"""
- if not expr:
- return All()
-
- def parse_terminal(pstate):
- next_tag = pstate.next_tag()
- if next_tag is _id:
- result = Id(pstate.next_match_obj().group(1))
- pstate.advance()
- return result
- elif next_tag is _tag:
- result = Tagged(pstate.next_match_obj().group(1))
- pstate.advance()
- return result
- elif next_tag is _writes:
- result = Writes(pstate.next_match_obj().group(1))
- pstate.advance()
- return result
- elif next_tag is _reads:
- result = Reads(pstate.next_match_obj().group(1))
- pstate.advance()
- return result
- elif next_tag is _iname:
- result = Iname(pstate.next_match_obj().group(1))
- pstate.advance()
- return result
- else:
- pstate.expected("terminal")
-
- def inner_parse(pstate, min_precedence=0):
- pstate.expect_not_end()
-
- if pstate.is_next(_not):
- pstate.advance()
- left_query = Not(inner_parse(pstate, _PREC_NOT))
- elif pstate.is_next(_openpar):
- pstate.advance()
- left_query = inner_parse(pstate)
- pstate.expect(_closepar)
- pstate.advance()
- else:
- left_query = parse_terminal(pstate)
-
- did_something = True
- while did_something:
- did_something = False
- if pstate.is_at_end():
- return left_query
- next_tag = pstate.next_tag()
-
- if next_tag is _and and _PREC_AND > min_precedence:
- pstate.advance()
- left_query = And(
- (left_query, inner_parse(pstate, _PREC_AND)))
- did_something = True
- elif next_tag is _or and _PREC_OR > min_precedence:
- pstate.advance()
- left_query = Or(
- (left_query, inner_parse(pstate, _PREC_OR)))
- did_something = True
-
- return left_query
+ # None -> ''
+ expr = expr if expr else ''
if isinstance(expr, MatchExpressionBase):
return expr
- from pytools.lex import LexIterator, lex, InvalidTokenError
- try:
- pstate = LexIterator(
- [(tag, s, idx, matchobj)
- for (tag, s, idx, matchobj) in lex(_LEX_TABLE, expr,
- match_objects=True)
- if tag is not _whitespace], expr)
- except InvalidTokenError as e:
- from loopy.diagnostic import LoopyError
- raise LoopyError(
- "invalid match expression: '{match_expr}' ({err_type}: {err_str})"
- .format(
- match_expr=expr,
- err_type=type(e).__name__,
- err_str=str(e)))
-
- if pstate.is_at_end():
- pstate.raise_parse_error("unexpected end of input")
-
- result = inner_parse(pstate)
- if not pstate.is_at_end():
- pstate.raise_parse_error("leftover input after completed parse")
+ _RESERVED_TO_TYPES = {
+ 'id': Id,
+ 'iname': Iname,
+ 'reads': Reads,
+ 'writes': Writes,
+ 'tag': Tagged
+ }
+
+ _BINARY_OPS_TO_TYPES = {
+ 'and': And,
+ 'or': Or
+ }
+
+ reserved = {
+ 'id': 'ID',
+ 'iname': 'INAME',
+ 'reads': 'READS',
+ 'writes': 'WRITES',
+ 'tag': 'TAG',
+
+ 'not': 'NOT',
+ 'or': 'OR',
+ 'and': 'AND'}
+
+ tokens = ('NAME', 'COLON', 'LPAREN', 'RPAREN', 'COMMA', 'ASSIGN',) + tuple(reserved.values()) # noqa
+
+ precedence = ( # noqa
+ ('left', 'OR'),
+ ('left', 'AND'),
+ ('left', 'NOT'),)
+
+ def t_NAME(t):
+ r'[a-zA-Z_*][a-zA-Z0-9_*.]*'
+ t.type = reserved.get(t.value, 'NAME')
+ return t
+
+ t_COLON = r':' # noqa
+ t_LPAREN = r'\(' # noqa
+ t_RPAREN = r'\)' # noqa
+ t_COMMA = r',' # noqa
+ t_ASSIGN = r'=' # noqa
+ t_ignore = ' \t' # noqa
+
+ def t_error(t):
+ raise RuntimeError("Illegal character '%s'." % t.value[0])
+
+ def p_expr_of_binary_ops(p):
+ '''expression : expression AND expression
+ | expression OR expression'''
+ children = []
+ if type(p[1]) == _BINARY_OPS_TO_TYPES[p[2]]:
+ children.extend(p[1].children)
+ else:
+ children.append(p[1])
+ if type(p[3]) == _BINARY_OPS_TO_TYPES[p[2]]:
+ children.extend(p[3].children)
+ else:
+ children.append(p[3])
+ p[0] = _BINARY_OPS_TO_TYPES[p[2]](tuple(children))
+
+ def p_expr_of_unary_ops(p):
+ 'expression : NOT expression'
+ p[0] = Not(p[2])
+
+ def p_parens(p):
+ '''expression : LPAREN expression RPAREN
+ parened_name : LPAREN parened_name RPAREN
+ tuple_of_names : LPAREN tuple_of_names RPAREN'''
+ p[0] = p[2]
+
+ def p_expr_as_terminal(p):
+ 'expression : term'
+ p[0] = p[1]
+
+ def p_terminal_as_property(p):
+ 'term : prop COLON NAME'
+ p[0] = _RESERVED_TO_TYPES[p[1]](p[3])
+
+ def p_prop(p):
+ '''prop : ID
+ | INAME
+ | TAG
+ | READS
+ | WRITES'''
+ p[0] = p[1]
+
+ def p_terminal_as_call(p):
+ '''term : NAME LPAREN kwargs RPAREN
+ | NAME LPAREN kwargs COMMA RPAREN'''
+ p[0] = Call(p[1], p[3])
+
+ def p_tuple_contents(p):
+ '''kwargs : kwargs COMMA kwarg
+ names : names COMMA parened_name'''
+ p[0] = p[1] + (p[3],)
+
+ def p_tuple_as_element(p):
+ 'kwargs : kwarg'
+ p[0] = (p[1],)
+
+ def p_names_tuple_of_len_2(p):
+ 'names : parened_name COMMA parened_name'
+ p[0] = (p[1], p[3])
+
+ def p_kwarg(p):
+ 'kwarg : NAME ASSIGN kwarg_val'
+ if not re.match(r'^[a-zA-Z_][a-zA-Z_0-9]*$', p[1]):
+ raise LoopyError('Invalid Kwarg attribute {0}. Can only contain '
+ 'alphanumerics and underscore.'.format(p[1]))
+ p[0] = (p[1], p[3])
+
+ def p_kwarg_val(p):
+ '''kwarg_val : expression
+ | tuple_of_names
+ | parened_name'''
+ p[0] = p[1]
+
+ def p_parened_name(p):
+ 'parened_name : NAME'
+ p[0] = p[1]
+
+ def p_tuple_of_names(p):
+ '''tuple_of_names : LPAREN names RPAREN
+ | LPAREN names COMMA RPAREN'''
+ p[0] = p[2]
+
+ def p_tuple_of_name_with_single_elem(p):
+ 'tuple_of_names : LPAREN parened_name COMMA RPAREN'
+ p[0] = (p[2], )
+
+ def p_empty_terminal(p):
+ 'term : empty'
+ p[0] = All()
+
+ def p_empty(p):
+ 'empty :'
+ pass
+
+ def p_error(p):
+ raise LoopyError("Syntax error at '%s'." % str(p))
+
+ lexer = lex.lex() # noqa
+
+ parser = yacc.yacc(debug=False)
+ result = parser.parse(expr)
return result
diff --git a/loopy/preprocess.py b/loopy/preprocess.py
index 2afcd3db4331d57e1e61c48ba521ebaa296ddbb2..7e093d04ccbbfff6bad60e610ede4248e258a154 100644
--- a/loopy/preprocess.py
+++ b/loopy/preprocess.py
@@ -1032,6 +1032,8 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
init_id = insn_id_gen(
"%s_%s_init" % (insn.id, "_".join(expr.inames)))
+ from loopy.kernel.instruction import serialize_instruction
+
init_insn = make_assignment(
id=init_id,
assignees=acc_vars,
@@ -1039,6 +1041,11 @@ def realize_reduction(kernel, insn_id_filter=None, unknown_types_ok=True,
within_inames_is_final=insn.within_inames_is_final,
depends_on=init_insn_depends_on,
expression=expr.operation.neutral_element(*arg_dtypes),
+ tags=frozenset([
+ "loopy.reduction.init(insn={0}"
+ ", inames=({1}))".format(
+ serialize_instruction(insn),
+ ','.join(expr.inames)+',')]),
predicates=insn.predicates,)
generated_insns.append(init_insn)
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index f5cf07b0e1d62212ce36edb48f47eb7de7d31451..3d2c602d8434d6987b44cf9d49780de84cf5ce3d 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -204,7 +204,7 @@ class StringifyMapper(StringifyMapperBase):
def map_reduction(self, expr, prec):
from pymbolic.mapper.stringifier import PREC_NONE
- return "%sreduce(%s, [%s], %s)" % (
+ return "%sreduce(%s, (%s), %s)" % (
"simul_" if expr.allow_simultaneous else "",
expr.operation, ", ".join(expr.inames),
self.rec(expr.expr, PREC_NONE))
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 89b4f5e639a031d3f2d4d89b470d2ccf5fb4b848..fd896435cd077ed9cea21df407ff64b04667cc83 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2986,6 +2986,23 @@ def test_shape_mismatch_check(ctx_factory):
prg(queue, a=a, b=b)
+def test_transform_addressing():
+ from loopy.match import parse_match
+
+ knl = lp.make_kernel(
+ "{[i]: 0<=i<10}",
+ """
+ a = sum(i, b[i]) {id=insn}
+ """)
+
+ knl = lp.realize_reduction(knl)
+
+ within = parse_match("loopy.reduction.init(insn=id:insn, inames=(i,))")
+
+    assert len([insn
+                for insn in knl.instructions if within(knl, insn)]) == 1
+
+
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])