diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index ebbe98072ef4789d9158ef4fa8ef3c098dc2413d..210570ae91d2a1d92b451b3b0b7161184b97cc2b 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -98,7 +98,7 @@ class Mapper(object): attribute. """ - def handle_unsupported_expression(self, expr, *args): + def handle_unsupported_expression(self, expr, *args, **kwargs): """Mapper method that is invoked for :class:`pymbolic.primitives.Expression` subclasses for which a mapper method does not exist in this mapper. @@ -135,35 +135,35 @@ class Mapper(object): rec = __call__ - def map_variable(self, expr, *args): - return self.map_algebraic_leaf(expr, *args) + def map_variable(self, expr, *args, **kwargs): + return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_subscript(self, expr, *args): - return self.map_algebraic_leaf(expr, *args) + def map_subscript(self, expr, *args, **kwargs): + return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_call(self, expr, *args): - return self.map_algebraic_leaf(expr, *args) + def map_call(self, expr, *args, **kwargs): + return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_lookup(self, expr, *args): - return self.map_algebraic_leaf(expr, *args) + def map_lookup(self, expr, *args, **kwargs): + return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_if_positive(self, expr, *args): - return self.map_algebraic_leaf(expr, *args) + def map_if_positive(self, expr, *args, **kwargs): + return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_rational(self, expr, *args): - return self.map_quotient(expr, *args) + def map_rational(self, expr, *args, **kwargs): + return self.map_quotient(expr, *args, **kwargs) - def map_foreign(self, expr, *args): + def map_foreign(self, expr, *args, **kwargs): """Mapper method dispatch for non-:mod:`pymbolic` objects.""" if isinstance(expr, primitives.VALID_CONSTANT_CLASSES): - return self.map_constant(expr, *args) + return self.map_constant(expr, *args, **kwargs) elif isinstance(expr, list): - return self.map_list(expr, *args) + return self.map_list(expr, *args, **kwargs) elif isinstance(expr, tuple): - return self.map_tuple(expr, *args) + return self.map_tuple(expr, *args, **kwargs) elif is_numpy_array(expr): - return self.map_numpy_array(expr, *args) + return self.map_numpy_array(expr, *args, **kwargs) else: raise ValueError( "%s encountered invalid foreign object: %s" % ( @@ -195,65 +195,67 @@ class CombineMapper(RecursiveMapper): :class:`pymbolic.mapper.dependency.DependencyMapper` is another example. """ - def map_call(self, expr, *args): + def map_call(self, expr, *args, **kwargs): return self.combine( - (self.rec(expr.function, *args),) + + (self.rec(expr.function, *args, **kwargs),) + tuple( - self.rec(child, *args) for child in expr.parameters) + self.rec(child, *args, **kwargs) for child in expr.parameters) ) - def map_call_with_kwargs(self, expr, *args): + def map_call_with_kwargs(self, expr, *args, **kwargs): return self.combine( - (self.rec(expr.function, *args),) + (self.rec(expr.function, *args, **kwargs),) + tuple( - self.rec(child, *args) for child in expr.parameters) + self.rec(child, *args, **kwargs) + for child in expr.parameters) + tuple( - self.rec(child, *args) for child in expr.kw_parameters.values()) + self.rec(child, *args, **kwargs) + for child in expr.kw_parameters.values()) ) - def map_subscript(self, expr, *args): + def map_subscript(self, expr, *args, **kwargs): return self.combine( - [self.rec(expr.aggregate, *args), - self.rec(expr.index, *args)]) + [self.rec(expr.aggregate, *args, **kwargs), + self.rec(expr.index, *args, **kwargs)]) - def map_lookup(self, expr, *args): - return self.rec(expr.aggregate, *args) + def map_lookup(self, expr, *args, **kwargs): + return self.rec(expr.aggregate, *args, **kwargs) - def map_sum(self, expr, *args): - return self.combine(self.rec(child, *args) + def map_sum(self, expr, *args, **kwargs): + return self.combine(self.rec(child, *args, **kwargs) for child in expr.children) map_product = map_sum - def map_quotient(self, expr, *args): + def map_quotient(self, expr, *args, **kwargs): return self.combine(( - self.rec(expr.numerator, *args), - self.rec(expr.denominator, *args))) + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr, *args): + def map_power(self, expr, *args, **kwargs): return self.combine(( - self.rec(expr.base, *args), - self.rec(expr.exponent, *args))) + self.rec(expr.base, *args, **kwargs), + self.rec(expr.exponent, *args, **kwargs))) - def map_polynomial(self, expr, *args): + def map_polynomial(self, expr, *args, **kwargs): return self.combine( - (self.rec(expr.base, *args),) + + (self.rec(expr.base, *args, **kwargs),) + tuple( - self.rec(coeff, *args) for exp, coeff in expr.data) + self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data) ) - def map_left_shift(self, expr, *args): + def map_left_shift(self, expr, *args, **kwargs): return self.combine( - self.rec(expr.shiftee, *args), - self.rec(expr.shift, *args)) + self.rec(expr.shiftee, *args, **kwargs), + self.rec(expr.shift, *args, **kwargs)) map_right_shift = map_left_shift - def map_bitwise_not(self, expr, *args): - return self.rec(expr.child, *args) + def map_bitwise_not(self, expr, *args, **kwargs): + return self.rec(expr.child, *args, **kwargs) map_bitwise_or = map_sum map_bitwise_xor = map_sum map_bitwise_and = map_sum @@ -262,27 +264,27 @@ class CombineMapper(RecursiveMapper): map_logical_and = map_sum map_logical_or = map_sum - def map_comparison(self, expr, *args): + def map_comparison(self, expr, *args, **kwargs): return self.combine(( - self.rec(expr.left, *args), - self.rec(expr.right, *args))) + self.rec(expr.left, *args, **kwargs), + self.rec(expr.right, *args, **kwargs))) map_max = map_sum map_min = map_sum - def map_list(self, expr, *args): - return self.combine(self.rec(child, *args) for child in expr) + def map_list(self, expr, *args, **kwargs): + return self.combine(self.rec(child, *args, **kwargs) for child in expr) map_tuple = map_list - def map_numpy_array(self, expr, *args): + def map_numpy_array(self, expr, *args, **kwargs): return self.combine(self.rec(el) for el in expr.flat) - def map_multivector(self, expr, *args): + def map_multivector(self, expr, *args, **kwargs): return self.combine(self.rec(coeff) for bits, coeff in expr.data.iteritems()) - def map_common_subexpression(self, expr, *args): - return self.rec(expr.child, *args) + def map_common_subexpression(self, expr, *args, **kwargs): + return self.rec(expr.child, *args, **kwargs) def map_if_positive(self, expr): return self.combine([ @@ -333,83 +335,83 @@ class IdentityMapper(Mapper): See :ref:`custom-manipulation` for an example of the manipulations that can be implemented this way. """ - def map_constant(self, expr, *args): + def map_constant(self, expr, *args, **kwargs): # leaf -- no need to rebuild return expr - def map_variable(self, expr, *args): + def map_variable(self, expr, *args, **kwargs): # leaf -- no need to rebuild return expr - def map_function_symbol(self, expr, *args): + def map_function_symbol(self, expr, *args, **kwargs): return expr - def map_call(self, expr, *args): + def map_call(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.function, *args), - tuple(self.rec(child, *args) + self.rec(expr.function, *args, **kwargs), + tuple(self.rec(child, *args, **kwargs) for child in expr.parameters)) - def map_call_with_kwargs(self, expr, *args): + def map_call_with_kwargs(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.function, *args), - tuple(self.rec(child, *args) + self.rec(expr.function, *args, **kwargs), + tuple(self.rec(child, *args, **kwargs) for child in expr.parameters), dict( - (key, self.rec(val, *args)) + (key, self.rec(val, *args, **kwargs)) for key, val in expr.kw_parameters.iteritems()) ) - def map_subscript(self, expr, *args): + def map_subscript(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.aggregate, *args), - self.rec(expr.index, *args)) + self.rec(expr.aggregate, *args, **kwargs), + self.rec(expr.index, *args, **kwargs)) - def map_lookup(self, expr, *args): + def map_lookup(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.aggregate, *args), + self.rec(expr.aggregate, *args, **kwargs), expr.name) - def map_sum(self, expr, *args): + def map_sum(self, expr, *args, **kwargs): from pymbolic.primitives import flattened_sum return flattened_sum(tuple( - self.rec(child, *args) for child in expr.children)) + self.rec(child, *args, **kwargs) for child in expr.children)) - def map_product(self, expr, *args): + def map_product(self, expr, *args, **kwargs): from pymbolic.primitives import flattened_product return flattened_product(tuple( - self.rec(child, *args) for child in expr.children)) + self.rec(child, *args, **kwargs) for child in expr.children)) - def map_quotient(self, expr, *args): - return expr.__class__(self.rec(expr.numerator, *args), - self.rec(expr.denominator, *args)) + def map_quotient(self, expr, *args, **kwargs): + return expr.__class__(self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs)) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr, *args): - return expr.__class__(self.rec(expr.base, *args), - self.rec(expr.exponent, *args)) + def map_power(self, expr, *args, **kwargs): + return expr.__class__(self.rec(expr.base, *args, **kwargs), + self.rec(expr.exponent, *args, **kwargs)) - def map_polynomial(self, expr, *args): - return expr.__class__(self.rec(expr.base, *args), - ((exp, self.rec(coeff, *args)) + def map_polynomial(self, expr, *args, **kwargs): + return expr.__class__(self.rec(expr.base, *args, **kwargs), + ((exp, self.rec(coeff, *args, **kwargs)) for exp, coeff in expr.data)) - def map_left_shift(self, expr, *args): + def map_left_shift(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.shiftee, *args), - self.rec(expr.shift, *args)) + self.rec(expr.shiftee, *args, **kwargs), + self.rec(expr.shift, *args, **kwargs)) map_right_shift = map_left_shift - def map_bitwise_not(self, expr, *args): + def map_bitwise_not(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.child, *args)) + self.rec(expr.child, *args, **kwargs)) - def map_bitwise_or(self, expr, *args): + def map_bitwise_or(self, expr, *args, **kwargs): return type(expr)(tuple( - self.rec(child, *args) for child in expr.children)) + self.rec(child, *args, **kwargs) for child in expr.children)) map_bitwise_xor = map_bitwise_or map_bitwise_and = map_bitwise_or @@ -418,17 +420,17 @@ class IdentityMapper(Mapper): map_logical_or = map_bitwise_or map_logical_and = map_bitwise_or - def map_comparison(self, expr, *args): + def map_comparison(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.left, *args), + self.rec(expr.left, *args, **kwargs), expr.operator, - self.rec(expr.right, *args)) + self.rec(expr.right, *args, **kwargs)) - def map_list(self, expr, *args): - return [self.rec(child, *args) for child in expr] + def map_list(self, expr, *args, **kwargs): + return [self.rec(child, *args, **kwargs) for child in expr] - def map_tuple(self, expr, *args): - return tuple(self.rec(child, *args) for child in expr) + def map_tuple(self, expr, *args, **kwargs): + return tuple(self.rec(child, *args, **kwargs) for child in expr) def map_numpy_array(self, expr): import numpy @@ -438,8 +440,8 @@ class IdentityMapper(Mapper): result[i] = self.rec(expr[i]) return result - def map_multivector(self, expr, *args): - return expr.map(lambda ch: self.rec(ch, *args)) + def map_multivector(self, expr, *args, **kwargs): + return expr.map(lambda ch: self.rec(ch, *args, **kwargs)) def map_common_subexpression(self, expr, *args, **kwargs): from pymbolic.primitives import is_zero @@ -453,38 +455,38 @@ class IdentityMapper(Mapper): expr.scope, **expr.get_extra_properties()) - def map_substitution(self, expr, *args): + def map_substitution(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.child, *args), + self.rec(expr.child, *args, **kwargs), expr.variables, - tuple(self.rec(v, *args) for v in expr.values)) + tuple(self.rec(v, *args, **kwargs) for v in expr.values)) - def map_derivative(self, expr, *args): + def map_derivative(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.child, *args), + self.rec(expr.child, *args, **kwargs), expr.variables) - def map_slice(self, expr, *args): + def map_slice(self, expr, *args, **kwargs): def do_map(expr): if expr is None: return expr else: - return self.rec(expr, *args) + return self.rec(expr, *args, **kwargs) return type(expr)( tuple(do_map(ch) for ch in expr.children)) - def map_if_positive(self, expr, *args): + def map_if_positive(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.criterion, *args), - self.rec(expr.then, *args), - self.rec(expr.else_, *args)) + self.rec(expr.criterion, *args, **kwargs), + self.rec(expr.then, *args, **kwargs), + self.rec(expr.else_, *args, **kwargs)) - def map_if(self, expr, *args): + def map_if(self, expr, *args, **kwargs): return type(expr)( - self.rec(expr.condition, *args), - self.rec(expr.then, *args), - self.rec(expr.else_, *args)) + self.rec(expr.condition, *args, **kwargs), + self.rec(expr.then, *args, **kwargs), + self.rec(expr.else_, *args, **kwargs)) # }}} @@ -496,99 +498,99 @@ class WalkMapper(RecursiveMapper): without propagating any result. Also calls :meth:`visit` for each visited subexpression. - .. method:: visit(expr, *args) + .. method:: visit(expr, *args, **kwargs) """ - def map_constant(self, expr, *args): - self.visit(expr, *args) + def map_constant(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) - def map_variable(self, expr, *args): - self.visit(expr, *args) + def map_variable(self, expr, *args, **kwargs): + self.visit(expr, *args, **kwargs) - def map_function_symbol(self, expr, *args): + def map_function_symbol(self, expr, *args, **kwargs): self.visit(expr) - def map_call(self, expr, *args): - if not self.visit(expr, *args): + def map_call(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.function, *args) + self.rec(expr.function, *args, **kwargs) for child in expr.parameters: - self.rec(child, *args) + self.rec(child, *args, **kwargs) - def map_call_with_kwargs(self, expr, *args): - if not self.visit(expr, *args): + def map_call_with_kwargs(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.function, *args) + self.rec(expr.function, *args, **kwargs) for child in expr.parameters: - self.rec(child, *args) + self.rec(child, *args, **kwargs) for child in expr.kw_parameters.values(): - self.rec(child, *args) + self.rec(child, *args, **kwargs) - def map_subscript(self, expr, *args): - if not self.visit(expr, *args): + def map_subscript(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.aggregate, *args) - self.rec(expr.index, *args) + self.rec(expr.aggregate, *args, **kwargs) + self.rec(expr.index, *args, **kwargs) - def map_lookup(self, expr, *args): - if not self.visit(expr, *args): + def map_lookup(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.aggregate, *args) + self.rec(expr.aggregate, *args, **kwargs) - def map_sum(self, expr, *args): - if not self.visit(expr, *args): + def map_sum(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return for child in expr.children: - self.rec(child, *args) + self.rec(child, *args, **kwargs) map_product = map_sum - def map_quotient(self, expr, *args): - if not self.visit(expr, *args): + def map_quotient(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.numerator, *args) - self.rec(expr.denominator, *args) + self.rec(expr.numerator, *args, **kwargs) + self.rec(expr.denominator, *args, **kwargs) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr, *args): - if not self.visit(expr, *args): + def map_power(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.base, *args) - self.rec(expr.exponent, *args) + self.rec(expr.base, *args, **kwargs) + self.rec(expr.exponent, *args, **kwargs) - def map_polynomial(self, expr, *args): - if not self.visit(expr, *args): + def map_polynomial(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.base, *args) + self.rec(expr.base, *args, **kwargs) for exp, coeff in expr.data: - self.rec(coeff, *args) + self.rec(coeff, *args, **kwargs) - def map_list(self, expr, *args): - if not self.visit(expr, *args): + def map_list(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return for child in expr: - self.rec(child, *args) + self.rec(child, *args, **kwargs) map_tuple = map_list - def map_numpy_array(self, expr, *args): - if not self.visit(expr, *args): + def map_numpy_array(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return from pytools import indices_in_shape for i in indices_in_shape(expr.shape): - self.rec(expr[i], *args) + self.rec(expr[i], *args, **kwargs) def map_multivector(self, expr, *args): if not self.visit(expr, *args): @@ -598,72 +600,72 @@ class WalkMapper(RecursiveMapper): self.rec(coeff) def map_common_subexpression(self, expr, *args, **kwargs): - if not self.visit(expr, *args): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.child, *args) + self.rec(expr.child, *args, **kwargs) - def map_left_shift(self, expr, *args): - if not self.visit(expr, *args): + def map_left_shift(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.shift, *args) - self.rec(expr.shiftee, *args) + self.rec(expr.shift, *args, **kwargs) + self.rec(expr.shiftee, *args, **kwargs) map_right_shift = map_left_shift - def map_bitwise_not(self, expr, *args): - if not self.visit(expr, *args): + def map_bitwise_not(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.child, *args) + self.rec(expr.child, *args, **kwargs) map_bitwise_or = map_sum map_bitwise_xor = map_sum map_bitwise_and = map_sum - def map_comparison(self, expr, *args): - if not self.visit(expr, *args): + def map_comparison(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.left, *args) - self.rec(expr.right, *args) + self.rec(expr.left, *args, **kwargs) + self.rec(expr.right, *args, **kwargs) map_logical_not = map_bitwise_not map_logical_and = map_sum map_logical_or = map_sum - def map_if(self, expr, *args): - if not self.visit(expr, *args): + def map_if(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.condition, *args) - self.rec(expr.then, *args) - self.rec(expr.else_, *args) + self.rec(expr.condition, *args, **kwargs) + self.rec(expr.then, *args, **kwargs) + self.rec(expr.else_, *args, **kwargs) - def map_if_positive(self, expr, *args): - if not self.visit(expr, *args): + def map_if_positive(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.criterion, *args) - self.rec(expr.then, *args) - self.rec(expr.else_, *args) + self.rec(expr.criterion, *args, **kwargs) + self.rec(expr.then, *args, **kwargs) + self.rec(expr.else_, *args, **kwargs) - def map_substitution(self, expr, *args): + def map_substitution(self, expr, *args, **kwargs): if not self.visit(expr): return - self.rec(expr.child, *args) + self.rec(expr.child, *args, **kwargs) for v in expr.values: - self.rec(v, *args) + self.rec(v, *args, **kwargs) - def map_derivative(self, expr, *args): - if not self.visit(expr, *args): + def map_derivative(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): return - self.rec(expr.child, *args) + self.rec(expr.child, *args, **kwargs) - def visit(self, expr, *args): + def visit(self, expr, *args, **kwargs): return True # }}} @@ -677,8 +679,8 @@ class CallbackMapper(RecursiveMapper): self.fallback_mapper = fallback_mapper fallback_mapper.rec = self.rec - def map_constant(self, expr, *args): - return self.function(expr, self, *args) + def map_constant(self, expr, *args, **kwargs): + return self.function(expr, self, *args, **kwargs) map_variable = map_constant map_function_symbol = map_constant diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index b0fd24f6b35df628fe82ec3ab7be5ae02940c5a0..4b34a76b4c7bf2803a0a3a9d5a63e7c4570d13e2 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -89,9 +89,9 @@ class StringifyMapper(pymbolic.mapper.Mapper): def join(self, joiner, iterable): return self.format(joiner.join("%s" for i in iterable), *iterable) - def join_rec(self, joiner, iterable, prec): + def join_rec(self, joiner, iterable, prec, *args, **kwargs): f = joiner.join("%s" for i in iterable) - return self.format(f, *[self.rec(i, prec) for i in iterable]) + return self.format(f, *[self.rec(i, prec, *args, **kwargs) for i in iterable]) def parenthesize(self, s): return "(%s)" % s @@ -106,14 +106,14 @@ class StringifyMapper(pymbolic.mapper.Mapper): # {{{ mappings - def handle_unsupported_expression(self, victim, enclosing_prec): + def handle_unsupported_expression(self, victim, enclosing_prec, *args, **kwargs): strifier = victim.stringifier() if isinstance(self, strifier): raise ValueError("stringifier '%s' can't handle '%s'" % (self, victim.__class__)) - return strifier(self.constant_mapper)(victim, enclosing_prec) + return strifier(self.constant_mapper)(victim, enclosing_prec, *args, **kwargs) - def map_constant(self, expr, enclosing_prec): + def map_constant(self, expr, enclosing_prec, *args, **kwargs): result = self.constant_mapper(expr) if not (result.startswith("(") and result.endswith(")")) \ @@ -123,72 +123,73 @@ class StringifyMapper(pymbolic.mapper.Mapper): else: return result - def map_variable(self, expr, enclosing_prec): + def map_variable(self, expr, enclosing_prec, *args, **kwargs): return expr.name - def map_function_symbol(self, expr, enclosing_prec): + def map_function_symbol(self, expr, enclosing_prec, *args, **kwargs): return expr.__class__.__name__ - def map_call(self, expr, enclosing_prec): + def map_call(self, expr, enclosing_prec, *args, **kwargs): return self.format("%s(%s)", - self.rec(expr.function, PREC_CALL), - self.join_rec(", ", expr.parameters, PREC_NONE)) + self.rec(expr.function, PREC_CALL, *args, **kwargs), + self.join_rec(", ", expr.parameters, PREC_NONE, *args, **kwargs)) - def map_call_with_kwargs(self, expr, enclosing_prec): + def map_call_with_kwargs(self, expr, enclosing_prec, *args, **kwargs): args_strings = ( - tuple(self.rec(ch, PREC_NONE) for ch in expr.parameters) + tuple(self.rec(ch, PREC_NONE, *args, **kwargs) + for ch in expr.parameters) + - tuple("%s=%s" % (name, self.rec(ch, PREC_NONE)) + tuple("%s=%s" % (name, self.rec(ch, PREC_NONE, *args, **kwargs)) for name, ch in expr.kw_parameters.items())) return self.format("%s(%s)", - self.rec(expr.function, PREC_CALL), + self.rec(expr.function, PREC_CALL, *args, **kwargs), ", ".join(args_strings)) - def map_subscript(self, expr, enclosing_prec): + def map_subscript(self, expr, enclosing_prec, *args, **kwargs): if isinstance(expr.index, tuple): - index_str = self.join_rec(", ", expr.index, PREC_NONE) + index_str = self.join_rec(", ", expr.index, PREC_NONE, *args, **kwargs) else: - index_str = self.rec(expr.index, PREC_NONE) + index_str = self.rec(expr.index, PREC_NONE, *args, **kwargs) return self.parenthesize_if_needed( self.format("%s[%s]", - self.rec(expr.aggregate, PREC_CALL), + self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), index_str), enclosing_prec, PREC_CALL) - def map_lookup(self, expr, enclosing_prec): + def map_lookup(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( self.format("%s.%s", - self.rec(expr.aggregate, PREC_CALL), + self.rec(expr.aggregate, PREC_CALL, *args, **kwargs), expr.name), enclosing_prec, PREC_CALL) - def map_sum(self, expr, enclosing_prec): + def map_sum(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" + ", expr.children, PREC_SUM), + self.join_rec(" + ", expr.children, PREC_SUM, *args, **kwargs), enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec): + def map_product(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec("*", expr.children, PREC_PRODUCT), + self.join_rec("*", expr.children, PREC_PRODUCT, *args, **kwargs), enclosing_prec, PREC_PRODUCT) - def map_quotient(self, expr, enclosing_prec): + def map_quotient(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( self.format("%s / %s", # space is necessary--otherwise '/*' becomes # start-of-comment in C. ('*' from dereference) - self.rec(expr.numerator, PREC_PRODUCT), - self.rec(expr.denominator, PREC_POWER)), # analogous to ^{-1} + self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), + self.rec(expr.denominator, PREC_POWER, *args, **kwargs)), # analogous to ^{-1} enclosing_prec, PREC_PRODUCT) - def map_floor_div(self, expr, enclosing_prec): + def map_floor_div(self, expr, enclosing_prec, *args, **kwargs): # (-1) * ((-1)*x // 5) should not reassociate. Therefore raise precedence # on the numerator and shield against surrounding products. result = self.format("%s // %s", - self.rec(expr.numerator, PREC_POWER), - self.rec(expr.denominator, PREC_POWER)) # analogous to ^{-1} + self.rec(expr.numerator, PREC_POWER, *args, **kwargs), + self.rec(expr.denominator, PREC_POWER, *args, **kwargs)) # analogous to ^{-1} # Note ">=", not ">" as in parenthesize_if_needed(). if enclosing_prec >= PREC_PRODUCT: @@ -196,101 +197,101 @@ class StringifyMapper(pymbolic.mapper.Mapper): else: return result - def map_power(self, expr, enclosing_prec): + def map_power(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( self.format("%s**%s", - self.rec(expr.base, PREC_POWER), - self.rec(expr.exponent, PREC_POWER)), + self.rec(expr.base, PREC_POWER, *args, **kwargs), + self.rec(expr.exponent, PREC_POWER, *args, **kwargs)), enclosing_prec, PREC_POWER) - def map_remainder(self, expr, enclosing_prec): + def map_remainder(self, expr, enclosing_prec, *args, **kwargs): return self.format("(%s %% %s)", - self.rec(expr.numerator, PREC_PRODUCT), - self.rec(expr.denominator, PREC_POWER)) # analogous to ^{-1} + self.rec(expr.numerator, PREC_PRODUCT, *args, **kwargs), + self.rec(expr.denominator, PREC_POWER, *args, **kwargs)) # analogous to ^{-1} - def map_polynomial(self, expr, enclosing_prec): + def map_polynomial(self, expr, enclosing_prec, *args, **kwargs): from pymbolic.primitives import flattened_sum return self.rec(flattened_sum( [coeff*expr.base**exp for exp, coeff in expr.data[::-1]]), - enclosing_prec) + enclosing_prec, *args, **kwargs) - def map_left_shift(self, expr, enclosing_prec): + def map_left_shift(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( self.format("%s << %s", - self.rec(expr.shiftee, PREC_SHIFT), - self.rec(expr.shift, PREC_SHIFT)), + self.rec(expr.shiftee, PREC_SHIFT, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT, *args, **kwargs)), enclosing_prec, PREC_SHIFT) - def map_right_shift(self, expr, enclosing_prec): + def map_right_shift(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( self.format("%s >> %s", - self.rec(expr.shiftee, PREC_SHIFT), - self.rec(expr.shift, PREC_SHIFT)), + self.rec(expr.shiftee, PREC_SHIFT, *args, **kwargs), + self.rec(expr.shift, PREC_SHIFT, *args, **kwargs)), enclosing_prec, PREC_SHIFT) - def map_bitwise_not(self, expr, enclosing_prec): + def map_bitwise_not(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - "~" + self.rec(expr.child, PREC_UNARY), + "~" + self.rec(expr.child, PREC_UNARY, *args, **kwargs), enclosing_prec, PREC_UNARY) - def map_bitwise_or(self, expr, enclosing_prec): + def map_bitwise_or(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" | ", expr.children, PREC_BITWISE_OR), + self.join_rec(" | ", expr.children, PREC_BITWISE_OR, *args, **kwargs), enclosing_prec, PREC_BITWISE_OR) - def map_bitwise_xor(self, expr, enclosing_prec): + def map_bitwise_xor(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR), + self.join_rec(" ^ ", expr.children, PREC_BITWISE_XOR, *args, **kwargs), enclosing_prec, PREC_BITWISE_XOR) - def map_bitwise_and(self, expr, enclosing_prec): + def map_bitwise_and(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" ^ ", expr.children, PREC_BITWISE_AND), + self.join_rec(" ^ ", expr.children, PREC_BITWISE_AND, *args, **kwargs), enclosing_prec, PREC_BITWISE_AND) - def map_comparison(self, expr, enclosing_prec): + def map_comparison(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( self.format("%s %s %s", - self.rec(expr.left, PREC_COMPARISON), + self.rec(expr.left, PREC_COMPARISON, *args, **kwargs), expr.operator, - self.rec(expr.right, PREC_COMPARISON)), + self.rec(expr.right, PREC_COMPARISON, *args, **kwargs)), enclosing_prec, PREC_COMPARISON) - def map_logical_not(self, expr, enclosing_prec): + def map_logical_not(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - "not " + self.rec(expr.child, PREC_UNARY), + "not " + self.rec(expr.child, PREC_UNARY, *args, **kwargs), enclosing_prec, PREC_UNARY) - def map_logical_or(self, expr, enclosing_prec): + def map_logical_or(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" or ", expr.children, PREC_LOGICAL_OR), + self.join_rec(" or ", expr.children, PREC_LOGICAL_OR, *args, **kwargs), enclosing_prec, PREC_LOGICAL_OR) - def map_logical_and(self, expr, enclosing_prec): + def map_logical_and(self, expr, enclosing_prec, *args, **kwargs): return self.parenthesize_if_needed( - self.join_rec(" and ", expr.children, PREC_LOGICAL_AND), + self.join_rec(" and ", expr.children, PREC_LOGICAL_AND, *args, **kwargs), enclosing_prec, PREC_LOGICAL_AND) - def map_list(self, expr, enclosing_prec): - return self.format("[%s]", self.join_rec(", ", expr, PREC_NONE)) + def map_list(self, expr, enclosing_prec, *args, **kwargs): + return self.format("[%s]", self.join_rec(", ", expr, PREC_NONE, *args, **kwargs)) map_vector = map_list - def map_tuple(self, expr, enclosing_prec): - el_str = ", ".join(self.rec(child, PREC_NONE) for child in expr) + def map_tuple(self, expr, enclosing_prec, *args, **kwargs): + el_str = ", ".join(self.rec(child, PREC_NONE, *args, **kwargs) for child in expr) if len(expr) == 1: el_str += "," return "(%s)" % el_str - def map_numpy_array(self, expr, enclosing_prec): + def map_numpy_array(self, expr, enclosing_prec, *args, **kwargs): import numpy from pytools import indices_in_shape str_array = numpy.zeros(expr.shape, dtype="object") max_length = 0 for i in indices_in_shape(expr.shape): - s = self.rec(expr[i], PREC_NONE) + s = self.rec(expr[i], PREC_NONE, *args, **kwargs) max_length = max(len(s), max_length) str_array[i] = s.replace("\n", "\n ") @@ -306,10 +307,10 @@ class StringifyMapper(pymbolic.mapper.Mapper): else: return "array(\n%s)" % "".join(lines) - def map_multivector(self, expr, enclosing_prec): - return expr.stringify(self.rec, enclosing_prec) + def map_multivector(self, expr, enclosing_prec, *args, **kwargs): + return expr.stringify(self.rec, enclosing_prec, *args, **kwargs) - def map_common_subexpression(self, expr, enclosing_prec): + def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs): from pymbolic.primitives import CommonSubexpression if type(expr) is CommonSubexpression: type_name = "CSE" @@ -317,49 +318,49 @@ class StringifyMapper(pymbolic.mapper.Mapper): type_name = type(expr).__name__ return self.format("%s(%s)", - type_name, self.rec(expr.child, PREC_NONE)) + type_name, self.rec(expr.child, PREC_NONE, *args, **kwargs)) - def map_if(self, expr, enclosing_prec): + def map_if(self, expr, enclosing_prec, *args, **kwargs): return "If(%s, %s, %s)" % ( - self.rec(expr.condition, PREC_NONE), - self.rec(expr.then, PREC_NONE), - self.rec(expr.else_, PREC_NONE)) + self.rec(expr.condition, PREC_NONE, *args, **kwargs), + self.rec(expr.then, PREC_NONE, *args, **kwargs), + self.rec(expr.else_, PREC_NONE, *args, **kwargs)) - def map_if_positive(self, expr, enclosing_prec): + def map_if_positive(self, expr, enclosing_prec, *args, **kwargs): return "If(%s > 0, %s, %s)" % ( - self.rec(expr.criterion, PREC_NONE), - self.rec(expr.then, PREC_NONE), - self.rec(expr.else_, PREC_NONE)) + self.rec(expr.criterion, PREC_NONE, *args, **kwargs), + self.rec(expr.then, PREC_NONE, *args, **kwargs), + self.rec(expr.else_, PREC_NONE, *args, **kwargs)) - def map_min(self, expr, enclosing_prec): + def map_min(self, expr, enclosing_prec, *args, **kwargs): what = type(expr).__name__.lower() return self.format("%s(%s)", - what, self.join_rec(", ", expr.children, PREC_NONE)) + what, self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs)) map_max = map_min - def map_derivative(self, expr, enclosing_prec): + def map_derivative(self, expr, enclosing_prec, *args, **kwargs): derivs = " ".join( "d/d%s" % v for v in expr.variables) return "%s %s" % ( - derivs, self.rec(expr.child, PREC_PRODUCT)) + derivs, self.rec(expr.child, PREC_PRODUCT, *args, **kwargs)) - def map_substitution(self, expr, enclosing_prec): + def map_substitution(self, expr, enclosing_prec, *args, **kwargs): substs = ", ".join( - "%s=%s" % (name, self.rec(val, PREC_NONE)) + "%s=%s" % (name, self.rec(val, PREC_NONE, *args, **kwargs)) for name, val in zip(expr.variables, expr.values)) - return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE), substs) + return "[%s]{%s}" % (self.rec(expr.child, PREC_NONE, *args, **kwargs), substs) - def map_slice(self, expr, enclosing_prec): + def map_slice(self, expr, enclosing_prec, *args, **kwargs): children = [] for child in expr.children: if child is None: children.append("") else: - children.append(self.rec(child, PREC_NONE)) + children.append(self.rec(child, PREC_NONE, *args, **kwargs)) return self.parenthesize_if_needed( self.join(":", children), @@ -367,13 +368,13 @@ class StringifyMapper(pymbolic.mapper.Mapper): # }}} - def __call__(self, expr, prec=PREC_NONE): + def __call__(self, expr, prec=PREC_NONE, *args, **kwargs): """Return a string corresponding to *expr*. If the enclosing precedence level *prec* is higher than *prec* (see :ref:`prec-constants`), parenthesize the result. """ - return pymbolic.mapper.Mapper.__call__(self, expr, prec) + return pymbolic.mapper.Mapper.__call__(self, expr, prec, *args, **kwargs) # }}} @@ -404,7 +405,7 @@ class CSESplittingStringifyMapperMixin(object): See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example of the use of this mix-in. """ - def map_common_subexpression(self, expr, enclosing_prec): + def map_common_subexpression(self, expr, enclosing_prec, *args, **kwargs): try: self.cse_to_name except AttributeError: @@ -415,7 +416,7 @@ class CSESplittingStringifyMapperMixin(object): try: cse_name = self.cse_to_name[expr.child] except KeyError: - str_child = self.rec(expr.child, PREC_NONE) + str_child = self.rec(expr.child, PREC_NONE, *args, **kwargs) if expr.prefix is not None: def generate_cse_names(): @@ -456,15 +457,15 @@ class SortingStringifyMapper(StringifyMapper): StringifyMapper.__init__(self, constant_mapper) self.reverse = reverse - def map_sum(self, expr, enclosing_prec): - entries = [self.rec(i, PREC_SUM) for i in expr.children] + def map_sum(self, expr, enclosing_prec, *args, **kwargs): + entries = [self.rec(i, PREC_SUM, *args, **kwargs) for i in expr.children] entries.sort(reverse=self.reverse) return self.parenthesize_if_needed( self.join(" + ", entries), enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec): - entries = [self.rec(i, PREC_PRODUCT) for i in expr.children] + def map_product(self, expr, enclosing_prec, *args, **kwargs): + entries = [self.rec(i, PREC_PRODUCT, *args, **kwargs) for i in expr.children] entries.sort(reverse=self.reverse) return self.parenthesize_if_needed( self.join("*", entries), @@ -480,7 +481,7 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): StringifyMapper.__init__(self, constant_mapper) self.reverse = reverse - def map_sum(self, expr, enclosing_prec): + def map_sum(self, expr, enclosing_prec, *args, **kwargs): def get_neg_product(expr): from pymbolic.primitives import is_zero, Product @@ -500,9 +501,9 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): for ch in expr.children: neg_prod = get_neg_product(ch) if neg_prod is not None: - negatives.append(self.rec(neg_prod, PREC_PRODUCT)) + negatives.append(self.rec(neg_prod, PREC_PRODUCT, *args, **kwargs)) else: - positives.append(self.rec(ch, PREC_SUM)) + positives.append(self.rec(ch, PREC_SUM, *args, **kwargs)) positives.sort(reverse=self.reverse) positives = " + ".join(positives) @@ -514,7 +515,7 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): return self.parenthesize_if_needed(result, enclosing_prec, PREC_SUM) - def map_product(self, expr, enclosing_prec): + def map_product(self, expr, enclosing_prec, *args, **kwargs): entries = [] i = 0 from pymbolic.primitives import is_zero @@ -526,10 +527,10 @@ class SimplifyingSortingStringifyMapper(StringifyMapper): # Otherwise two unary minus signs merge into a pre-decrement. entries.append( self.format( - "- %s", self.rec(expr.children[i+1], PREC_UNARY))) + "- %s", self.rec(expr.children[i+1], PREC_UNARY, *args, **kwargs))) i += 2 else: - entries.append(self.rec(child, PREC_PRODUCT)) + entries.append(self.rec(child, PREC_PRODUCT, *args, **kwargs)) i += 1 entries.sort(reverse=self.reverse)