diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 678b3c6bfb2f210190328fd365c18d24b7393895..686e682b4816c805f44f5391b7d055a43753d730 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -63,6 +63,8 @@ class Mapper(object): return self.map_constant(expr, *args) elif isinstance(expr, list): return self.map_list(expr, *args) + elif isinstance(expr, tuple): + return self.map_tuple(expr, *args) elif is_numpy_array(expr): return self.map_numpy_array(expr, *args) else: @@ -138,8 +140,10 @@ class CombineMapper(RecursiveMapper): self.rec(coeff, *args) for exp, coeff in expr.data) ) - map_list = map_sum - map_vector = map_sum + def map_list(self, expr, *args): + return self.combine(self.rec(child, *args) for child in expr) + + map_tuple = map_list def map_numpy_array(self, expr, *args): return self.combine(self.rec(el) for el in expr.flat) @@ -153,11 +157,6 @@ class CombineMapper(RecursiveMapper): self.rec(expr.then), self.rec(expr.else_)]) - def map_foreign(self, expr, *args): - if isinstance(expr, tuple): - return self.combine([self.rec(child) for child in expr]) - else: - return RecursiveMapper.map_foreign(self, expr, *args) @@ -215,8 +214,11 @@ class IdentityMapperBase(object): ((exp, self.rec(coeff, *args)) for exp, coeff in expr.data)) - map_list = map_sum - map_vector = map_sum + def map_list(self, expr, *args): + return [self.rec(child, *args) for child in expr] + + def map_tuple(self, expr, *args): + return tuple(self.rec(child, *args) for child in expr) def map_numpy_array(self, expr): import numpy