diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index be2409532ed0ec16c41cd5aed2d79b77ea18c865..691ea71f188ca159e67750e81c2e15d29583d77d 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -1,3 +1,15 @@ +try: + import numpy + + def is_numpy_array(val): + return isinstance(val, numpy.ndarray) +except ImportError: + def is_numpy_array(ary): + return False + + + + class Mapper(object): def __init__(self, recurse=True): self.Recurse = True @@ -40,6 +52,8 @@ class Mapper(object): return self.map_constant(expr, *args, **kwargs) elif isinstance(expr, list): return self.map_list(expr, *args, **kwargs) + elif is_numpy_array(expr): + return self.map_numpy_array(expr, *args, **kwargs) else: raise ValueError, "encountered invalid foreign object: %s" % repr(expr) @@ -108,6 +122,8 @@ class CombineMapper(RecursiveMapper): map_list = map_sum map_vector = map_sum + def map_numpy_array(self, expr): + return self.combine(expr.flat) @@ -158,17 +174,21 @@ class IdentityMapperBase(object): ((exp, self.rec(coeff, *args, **kwargs)) for exp, coeff in expr.data)) - map_list = map_sum map_vector = map_sum + def map_numpy_array(self, expr): + import numpy + result = numpy.empty(expr.shape, dtype=object) + from pytools import indices_in_shape + for i in indices_in_shape(expr.shape): + result[i] = self.rec(expr[i]) + return result class IdentityMapper(IdentityMapperBase, RecursiveMapper): - def handle_unsupported_expression(self, expr, *args, **kwargs): - return expr + pass class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper): - def handle_unsupported_expression(self, expr, *args, **kwargs): - return expr + pass diff --git a/src/mapper/constant_folder.py b/src/mapper/constant_folder.py index 69acbb2c6e2e407fc826e54cd28163e165662957..5da8218bbb1b1a54f6a47b24f881c0e8f13c4549 100644 --- a/src/mapper/constant_folder.py +++ b/src/mapper/constant_folder.py @@ -34,6 +34,9 @@ class ConstantFoldingMapperBase(object): return self.fold(expr, Sum, operator.add, Sum) + def handle_unsupported_expression(self, expr): + return expr + class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index 30cbb55f7757a0c2e582190c6ace59a36fb30718..4ff148c735ddfbc6583e5846afed838b5d5a4a1d 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -68,8 +68,15 @@ class EvaluationMapper(RecursiveMapper): return result def map_list(self, expr): - return [self.rec(child) for child in expr.Children] - + return [self.rec(child) for child in expr] + + def map_numpy_array(self, expr): + import numpy + result = numpy.empty(expr.shape, dtype=object) + from pytools import indices_in_shape + for i in indices_in_shape(expr.shape): + result[i] = self.rec(expr[i]) + return result diff --git a/src/mapper/expander.py b/src/mapper/expander.py index 6978bc76b20765c63447bbcd2f8a9eef539fe2d9..0822075333fa286a1c9792f5e29fc5acfa688306 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -66,6 +66,8 @@ class ExpandMapper(IdentityMapper): else: return IdentitityMapper.map_power(expr) + def handle_unsupported_expression(self, expr): + return expr