Skip to content
Snippets Groups Projects
Commit 181e05ed authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add WalkMapper, CallbackMapper.

parent 5da20fbe
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,7 @@ class UnsupportedExpressionError(ValueError): ...@@ -19,6 +19,7 @@ class UnsupportedExpressionError(ValueError):
pass pass
# {{{ mapper base
class Mapper(object): class Mapper(object):
def handle_unsupported_expression(self, expr, *args): def handle_unsupported_expression(self, expr, *args):
...@@ -71,6 +72,8 @@ class Mapper(object): ...@@ -71,6 +72,8 @@ class Mapper(object):
raise ValueError, "%s encountered invalid foreign object: %s" % ( raise ValueError, "%s encountered invalid foreign object: %s" % (
self.__class__, repr(expr)) self.__class__, repr(expr))
# }}}
...@@ -80,6 +83,7 @@ class RecursiveMapper(Mapper): ...@@ -80,6 +83,7 @@ class RecursiveMapper(Mapper):
# {{{ combine mapper
class CombineMapper(RecursiveMapper): class CombineMapper(RecursiveMapper):
def map_call(self, expr, *args): def map_call(self, expr, *args):
...@@ -144,7 +148,9 @@ class CombineMapper(RecursiveMapper): ...@@ -144,7 +148,9 @@ class CombineMapper(RecursiveMapper):
self.rec(expr.else_)]) self.rec(expr.else_)])
# }}}
# {{{ identity mapper
class IdentityMapperBase(object): class IdentityMapperBase(object):
def map_constant(self, expr, *args): def map_constant(self, expr, *args):
...@@ -244,8 +250,111 @@ class IdentityMapper(IdentityMapperBase, RecursiveMapper): ...@@ -244,8 +250,111 @@ class IdentityMapper(IdentityMapperBase, RecursiveMapper):
class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper): class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper):
pass pass
# }}}
# {{{ walk mapper
class WalkMapper(RecursiveMapper):
def map_constant(self, expr, *args):
pass
def map_variable(self, expr, *args):
pass
def map_function_symbol(self, expr, *args):
pass
def map_call(self, expr, *args):
self.rec(expr.function, *args)
for child in expr.parameters:
self.rec(child, *args)
def map_subscript(self, expr, *args):
self.rec(expr.aggregate, *args)
self.rec(expr.index, *args)
def map_lookup(self, expr, *args):
self.rec(expr.aggregate, *args)
def map_negation(self, expr, *args):
self.rec(expr.child, *args)
def map_sum(self, expr, *args):
for child in expr.children:
self.rec(child, *args)
map_product = map_sum
def map_quotient(self, expr, *args):
self.rec(expr.numerator, *args)
self.rec(expr.denominator, *args)
map_floor_div = map_quotient
map_remainder = map_quotient
def map_power(self, expr, *args):
self.rec(expr.base, *args)
self.rec(expr.exponent, *args)
def map_polynomial(self, expr, *args):
self.rec(expr.base, *args)
for exp, coeff in expr.data:
self.rec(coeff, *args)
def map_list(self, expr, *args):
for child in expr:
self.rec(child, *args)
map_tuple = map_list
def map_numpy_array(self, expr):
from pytools import indices_in_shape
for i in indices_in_shape(expr.shape):
self.rec(expr[i])
def map_common_subexpression(self, expr, *args, **kwargs):
self.rec(expr.child)
def map_if_positive(self, expr):
self.rec(expr.criterion)
self.rec(expr.then)
self.rec(expr.else_)
# }}}
# {{{ callback mapper
class CallbackMapper(RecursiveMapper):
def __init__(self, function, fallback_mapper):
self.function = function
self.fallback_mapper = fallback_mapper
fallback_mapper.rec = self.rec
def map_constant(self, expr, *args):
return self.function(expr, self, *args)
map_variable = map_constant
map_function_symbol = map_constant
map_call = map_constant
map_subscript = map_constant
map_lookup = map_constant
map_negation = map_constant
map_sum = map_constant
map_product = map_constant
map_quotient = map_constant
map_floor_div = map_constant
map_remainder = map_constant
map_power = map_constant
map_polynomial = map_constant
map_list = map_constant
map_tuple = map_constant
map_numpy_array = map_constant
map_common_subexpression = map_constant
map_if_positive = map_constant
# }}}
# {{{ cse caching mixin
class CSECachingMapperMixin(object): class CSECachingMapperMixin(object):
def map_common_subexpression(self, expr): def map_common_subexpression(self, expr):
...@@ -260,3 +369,7 @@ class CSECachingMapperMixin(object): ...@@ -260,3 +369,7 @@ class CSECachingMapperMixin(object):
result = self.map_common_subexpression_uncached(expr) result = self.map_common_subexpression_uncached(expr)
ccd[expr] = result ccd[expr] = result
return result return result
# }}}
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment