diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 58da591e567b5d85f4094d5e41ff5f8c9e912650..8f7c6ae2a7abf2866c1c80ce6df225f11cfbd473 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -200,6 +200,15 @@ class CombineMapper(RecursiveMapper): self.rec(child, *args) for child in expr.parameters) ) + def map_call_with_kwargs(self, expr, *args): + return self.combine( + (self.rec(expr.function, *args),) + + tuple( + self.rec(child, *args) for child in expr.parameters) + + tuple( + self.rec(child, *args) for child in expr.kw_parameters.values()) + ) + def map_subscript(self, expr, *args): return self.combine( [self.rec(expr.aggregate, *args), @@ -309,18 +318,28 @@ class IdentityMapper(Mapper): return expr def map_call(self, expr, *args): - return expr.__class__( + return type(expr)( self.rec(expr.function, *args), tuple(self.rec(child, *args) for child in expr.parameters)) + def map_call_with_kwargs(self, expr, *args): + return type(expr)( + self.rec(expr.function, *args), + tuple(self.rec(child, *args) + for child in expr.parameters), + dict( + (key, self.rec(val, *args)) + for key, val in expr.kw_parameters.iteritems()) + ) + def map_subscript(self, expr, *args): return expr.__class__( self.rec(expr.aggregate, *args), self.rec(expr.index, *args)) def map_lookup(self, expr, *args): - return expr.__class__( + return type(expr)( self.rec(expr.aggregate, *args), expr.name) @@ -469,6 +488,17 @@ class WalkMapper(RecursiveMapper): for child in expr.parameters: self.rec(child, *args) + def map_call_with_kwargs(self, expr, *args): + if not self.visit(expr, *args): + return + + self.rec(expr.function, *args) + for child in expr.parameters: + self.rec(child, *args) + + for child in expr.kw_parameters.values(): + self.rec(child, *args) + def map_subscript(self, expr, *args): if not self.visit(expr, *args): return diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index ab98a9d0673f7ad99b508608457bee3a50013ef6..89270e081106c7045b11c34cc2379291971a3475 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -67,6 +67,14 @@ class EvaluationMapper(RecursiveMapper, CSECachingMapperMixin): def map_call(self, expr): return self.rec(expr.function)(*[self.rec(par) for par in expr.parameters]) + def map_call_with_kwargs(self, expr): + args = [self.rec(par) for par in expr.parameters] + kwargs = dict( + (k, self.rec(v)) + for k, v in expr.kw_parameters.items()) + + return self.rec(expr.function)(*args, **kwargs) + def map_subscript(self, expr): return self.rec(expr.aggregate)[self.rec(expr.index)] diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index db09210c52ab5fe9a7acec66cee143e7d06eb9fb..b0fd24f6b35df628fe82ec3ab7be5ae02940c5a0 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -134,6 +134,16 @@ class StringifyMapper(pymbolic.mapper.Mapper): self.rec(expr.function, PREC_CALL), self.join_rec(", ", expr.parameters, PREC_NONE)) + def map_call_with_kwargs(self, expr, enclosing_prec): + args_strings = ( + tuple(self.rec(ch, PREC_NONE) for ch in expr.parameters) + + + tuple("%s=%s" % (name, self.rec(ch, PREC_NONE)) + for name, ch in expr.kw_parameters.items())) + return self.format("%s(%s)", + self.rec(expr.function, PREC_CALL), + ", ".join(args_strings)) + def map_subscript(self, expr, enclosing_prec): if isinstance(expr.index, tuple): index_str = self.join_rec(", ", expr.index, PREC_NONE) diff --git a/pymbolic/parser.py b/pymbolic/parser.py index c696750a17c5a4a2512813cbebb73ba476dfdcad..fb90aed8916114e0d351634d45016428a348209a 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -44,6 +44,8 @@ _comma = intern("comma") _dot = intern("dot") _colon = intern("colon") +_assign = intern("assign") + _equal = intern("equal") _notequal = intern("notequal") _less = intern("less") @@ -85,6 +87,7 @@ class Parser: lex_table = [ (_equal, pytools.lex.RE(r"==")), (_notequal, pytools.lex.RE(r"!=")), + (_equal, pytools.lex.RE(r"==")), (_lessequal, pytools.lex.RE(r"\<=")), (_greaterequal, pytools.lex.RE(r"\>=")), @@ -92,6 +95,8 @@ class Parser: (_less, pytools.lex.RE(r"\<")), (_greater, pytools.lex.RE(r"\>")), + (_assign, pytools.lex.RE(r"=")), + (_and, pytools.lex.RE(r"and")), (_or, pytools.lex.RE(r"or")), (_not, pytools.lex.RE(r"not")), @@ -226,17 +231,13 @@ class Parser: if next_tag is _openpar and _PREC_CALL > min_precedence: pstate.advance() - pstate.expect_not_end() - if next_tag is _closepar: - pstate.advance() - left_exp = primitives.Call(left_exp, ()) + args, kwargs = self.parse_arglist(pstate) + + if kwargs: + left_exp = primitives.CallWithKwargs(left_exp, args, kwargs) else: - args = self.parse_expression(pstate) - if not isinstance(args, tuple): - args = (args,) left_exp = primitives.Call(left_exp, args) - pstate.expect(_closepar) - pstate.advance() + did_something = True elif next_tag is _openbracket and _PREC_CALL > min_precedence: pstate.advance() @@ -335,6 +336,49 @@ class Parser: return left_exp, did_something + def parse_arglist(self, pstate): + pstate.expect_not_end() + + args = [] + kwargs = {} + + comma_allowed = False + while True: + pstate.expect_not_end() + + saw_comma = False + if pstate.next_tag() is _comma: + saw_comma = True + if not comma_allowed: + pstate.raise_parse_error("comma not expected") + pstate.advance() + pstate.expect_not_end() + + if pstate.next_tag() is _closepar: + pstate.advance() + return tuple(args), kwargs + + if not saw_comma and comma_allowed: + pstate.raise_parse_error("comma expected") + + if (pstate.next_tag() is _identifier + and not pstate.is_at_end(1) + and pstate.next_tag(1) == _assign): + kw = pstate.next_str() + pstate.advance() + pstate.advance() + + kwargs[kw] = self.parse_expression(pstate, _PREC_COMMA) + else: + if kwargs: + pstate.raise_parse_error( + "positional argument after keyword " + "argument not allowed") + + args.append(self.parse_expression(pstate, _PREC_COMMA)) + + comma_allowed = True + def __call__(self, expr_str): lex_result = [(tag, s, idx) for (tag, s, idx) in pytools.lex.lex(self.lex_table, expr_str) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index b35a9ecd588927c0bbd1c92c893f735b211f58bd..84ce880e2153d135bebbcd06045c9bdbaadd7561 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -55,6 +55,10 @@ Sums, products and such :undoc-members: :members: mapper_method +.. autoclass:: CallWithKwargs + :undoc-members: + :members: mapper_method + .. autoclass:: Subscript :undoc-members: :members: mapper_method @@ -368,8 +372,11 @@ class Expression(object): def __neg__(self): return -1*self - def __call__(self, *pars): - return Call(self, pars) + def __call__(self, *args, **kwargs): + if kwargs: + return CallWithKwargs(self, args, kwargs) + else: + return Call(self, args) def __getitem__(self, subscript): if _SUBSCRIPT_BY_GETITEM: @@ -559,17 +566,19 @@ class Call(AlgebraicLeaf): """A function invocation. .. attribute:: function + + A :class:`Expression` that evaluates to a function. + .. attribute:: parameters + + A :class:`tuple` of positional paramters, each element + of which is a :class:`Expression` or a constant. + """ init_arg_names = ("function", "parameters",) def __init__(self, function, parameters): - """ - :arg function: A :class:`Expression` that evaluates to a function. - :arg parameters: A :class:`tuple` of positional paramters. - """ - self.function = function self.parameters = parameters @@ -589,6 +598,55 @@ class Call(AlgebraicLeaf): mapper_method = intern("map_call") +class CallWithKwargs(AlgebraicLeaf): + """A function invocation with keyword arguments. + + .. attribute:: function + + A :class:`Expression` that evaluates to a function. + + .. attribute:: parameters + + A :class:`tuple` of positional paramters, each element + of which is a :class:`Expression` or a constant. + + .. attribute:: kw_parameters + + A dictionary mapping names to arguments, , each + of which is a :class:`Expression` or a constant, + or an equivalent value accepted by the :class:`dict` + constructor. + """ + + init_arg_names = ("function", "parameters", "kw_parameters") + + def __init__(self, function, parameters, kw_parameters): + self.function = function + self.parameters = parameters + + if isinstance(kw_parameters, dict): + self.kw_parameters = kw_parameters + else: + self.kw_parameters = dict(kw_parameters) + + try: + arg_count = self.function.arg_count + except AttributeError: + pass + else: + if len(self.parameters) != arg_count: + raise TypeError("%s called with wrong number of arguments " + "(need %d, got %d)" % ( + self.function, arg_count, len(parameters))) + + def __getinitargs__(self): + return (self.function, + self.parameters, + tuple(sorted(self.kw_parameters.values()))) + + mapper_method = intern("map_call_with_kwargs") + + class Subscript(AlgebraicLeaf): """An array subscript. diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index ca83b7225f4abbd3be973dc47fd930390eb7b197..e5c32d0d47ddc05ae59432ebaef73a469931c4f6 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -229,9 +229,27 @@ def test_parser(): assert parse("f((x,),z)") == f((x,), z) assert parse("f(x,(y,z),z)") == f(x, (y, z), z) + assert parse("f(x,(y,z),z, name=15)") == f(x, (y, z), z, name=15) + assert parse("f(x,(y,z),z, name=15, name2=17)") == f( + x, (y, z), z, name=15, name2=17) + # }}} +def test_mappers(): + from pymbolic import variables + f, x, y, z = variables("f x y z") + + for expr in [ + f(x, (y, z), name=z**2) + ]: + from pymbolic.mapper import WalkMapper + from pymbolic.mapper.dependency import DependencyMapper + str(expr) + IdentityMapper()(expr) + WalkMapper()(expr) + DependencyMapper()(expr) + # {{{ geometric algebra @pytest.mark.parametrize("dims", [2, 3, 4, 5])