diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 3f048664de7b013d36b62e2876a76448a7931839..e5d8194a473fb022c8cb94f7ea4a29a8134fc6bf 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -19,6 +19,7 @@ class UnsupportedExpressionError(ValueError): pass +# {{{ mapper base class Mapper(object): def handle_unsupported_expression(self, expr, *args): @@ -71,6 +72,8 @@ class Mapper(object): raise ValueError, "%s encountered invalid foreign object: %s" % ( self.__class__, repr(expr)) +# }}} + @@ -80,6 +83,7 @@ class RecursiveMapper(Mapper): +# {{{ combine mapper class CombineMapper(RecursiveMapper): def map_call(self, expr, *args): @@ -144,7 +148,9 @@ class CombineMapper(RecursiveMapper): self.rec(expr.else_)]) +# }}} +# {{{ identity mapper class IdentityMapperBase(object): def map_constant(self, expr, *args): @@ -244,8 +250,111 @@ class IdentityMapper(IdentityMapperBase, RecursiveMapper): class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper): pass +# }}} + +# {{{ walk mapper + +class WalkMapper(RecursiveMapper): + def map_constant(self, expr, *args): + pass + + def map_variable(self, expr, *args): + pass + + def map_function_symbol(self, expr, *args): + pass + + def map_call(self, expr, *args): + self.rec(expr.function, *args) + for child in expr.parameters: + self.rec(child, *args) + + def map_subscript(self, expr, *args): + self.rec(expr.aggregate, *args) + self.rec(expr.index, *args) + + def map_lookup(self, expr, *args): + self.rec(expr.aggregate, *args) + + def map_negation(self, expr, *args): + self.rec(expr.child, *args) + + def map_sum(self, expr, *args): + for child in expr.children: + self.rec(child, *args) + + map_product = map_sum + + def map_quotient(self, expr, *args): + self.rec(expr.numerator, *args) + self.rec(expr.denominator, *args) + + map_floor_div = map_quotient + map_remainder = map_quotient + + def map_power(self, expr, *args): + self.rec(expr.base, *args) + self.rec(expr.exponent, *args) + + def map_polynomial(self, expr, *args): + self.rec(expr.base, *args) + for exp, coeff in expr.data: + self.rec(coeff, *args) + def map_list(self, expr, *args): + for child in expr: + self.rec(child, *args) + + map_tuple = map_list + def map_numpy_array(self, expr): + from pytools import indices_in_shape + for i in indices_in_shape(expr.shape): + self.rec(expr[i]) + + def map_common_subexpression(self, expr, *args, **kwargs): + self.rec(expr.child) + + def map_if_positive(self, expr): + self.rec(expr.criterion) + self.rec(expr.then) + self.rec(expr.else_) + +# }}} + +# {{{ callback mapper + +class CallbackMapper(RecursiveMapper): + def __init__(self, function, fallback_mapper): + self.function = function + self.fallback_mapper = fallback_mapper + fallback_mapper.rec = self.rec + + def map_constant(self, expr, *args): + return self.function(expr, self, *args) + + map_variable = map_constant + map_function_symbol = map_constant + map_call = map_constant + map_subscript = map_constant + map_lookup = map_constant + map_negation = map_constant + map_sum = map_constant + map_product = map_constant + map_quotient = map_constant + map_floor_div = map_constant + map_remainder = map_constant + map_power = map_constant + map_polynomial = map_constant + map_list = map_constant + map_tuple = map_constant + map_numpy_array = map_constant + map_common_subexpression = map_constant + map_if_positive = map_constant + +# }}} + +# {{{ cse caching mixin class CSECachingMapperMixin(object): def map_common_subexpression(self, expr): @@ -260,3 +369,7 @@ class CSECachingMapperMixin(object): result = self.map_common_subexpression_uncached(expr) ccd[expr] = result return result + +# }}} + +# vim: foldmethod=marker