Skip to content
Snippets Groups Projects
Commit 181e05ed authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Add WalkMapper, CallbackMapper.

parent 5da20fbe
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ class UnsupportedExpressionError(ValueError):
pass
# {{{ mapper base
class Mapper(object):
def handle_unsupported_expression(self, expr, *args):
......@@ -71,6 +72,8 @@ class Mapper(object):
raise ValueError, "%s encountered invalid foreign object: %s" % (
self.__class__, repr(expr))
# }}}
......@@ -80,6 +83,7 @@ class RecursiveMapper(Mapper):
# {{{ combine mapper
class CombineMapper(RecursiveMapper):
def map_call(self, expr, *args):
......@@ -144,7 +148,9 @@ class CombineMapper(RecursiveMapper):
self.rec(expr.else_)])
# }}}
# {{{ identity mapper
class IdentityMapperBase(object):
def map_constant(self, expr, *args):
......@@ -244,8 +250,111 @@ class IdentityMapper(IdentityMapperBase, RecursiveMapper):
class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper):
pass
# }}}
# {{{ walk mapper
class WalkMapper(RecursiveMapper):
def map_constant(self, expr, *args):
pass
def map_variable(self, expr, *args):
pass
def map_function_symbol(self, expr, *args):
pass
def map_call(self, expr, *args):
self.rec(expr.function, *args)
for child in expr.parameters:
self.rec(child, *args)
def map_subscript(self, expr, *args):
self.rec(expr.aggregate, *args)
self.rec(expr.index, *args)
def map_lookup(self, expr, *args):
self.rec(expr.aggregate, *args)
def map_negation(self, expr, *args):
self.rec(expr.child, *args)
def map_sum(self, expr, *args):
for child in expr.children:
self.rec(child, *args)
map_product = map_sum
def map_quotient(self, expr, *args):
self.rec(expr.numerator, *args)
self.rec(expr.denominator, *args)
map_floor_div = map_quotient
map_remainder = map_quotient
def map_power(self, expr, *args):
self.rec(expr.base, *args)
self.rec(expr.exponent, *args)
def map_polynomial(self, expr, *args):
self.rec(expr.base, *args)
for exp, coeff in expr.data:
self.rec(coeff, *args)
def map_list(self, expr, *args):
for child in expr:
self.rec(child, *args)
map_tuple = map_list
def map_numpy_array(self, expr):
from pytools import indices_in_shape
for i in indices_in_shape(expr.shape):
self.rec(expr[i])
def map_common_subexpression(self, expr, *args, **kwargs):
self.rec(expr.child)
def map_if_positive(self, expr):
self.rec(expr.criterion)
self.rec(expr.then)
self.rec(expr.else_)
# }}}
# {{{ callback mapper
class CallbackMapper(RecursiveMapper):
def __init__(self, function, fallback_mapper):
self.function = function
self.fallback_mapper = fallback_mapper
fallback_mapper.rec = self.rec
def map_constant(self, expr, *args):
return self.function(expr, self, *args)
map_variable = map_constant
map_function_symbol = map_constant
map_call = map_constant
map_subscript = map_constant
map_lookup = map_constant
map_negation = map_constant
map_sum = map_constant
map_product = map_constant
map_quotient = map_constant
map_floor_div = map_constant
map_remainder = map_constant
map_power = map_constant
map_polynomial = map_constant
map_list = map_constant
map_tuple = map_constant
map_numpy_array = map_constant
map_common_subexpression = map_constant
map_if_positive = map_constant
# }}}
# {{{ cse caching mixin
class CSECachingMapperMixin(object):
def map_common_subexpression(self, expr):
......@@ -260,3 +369,7 @@ class CSECachingMapperMixin(object):
result = self.map_common_subexpression_uncached(expr)
ccd[expr] = result
return result
# }}}
# vim: foldmethod=marker
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment