From a624dfac1c1441898e8686bf7b442291be288add Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 13 May 2008 17:59:38 -0400 Subject: [PATCH] Manage numpy arrays as non-native expressions. (+) Don't id-map unsupported expressions by default. --- src/mapper/__init__.py | 30 +++++++++++++++++++++++++----- src/mapper/constant_folder.py | 3 +++ src/mapper/evaluator.py | 11 +++++++++-- src/mapper/expander.py | 2 ++ 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/mapper/__init__.py b/src/mapper/__init__.py index be24095..691ea71 100644 --- a/src/mapper/__init__.py +++ b/src/mapper/__init__.py @@ -1,3 +1,15 @@ +try: + import numpy + + def is_numpy_array(val): + return isinstance(val, numpy.ndarray) +except ImportError: + def is_numpy_array(ary): + return False + + + + class Mapper(object): def __init__(self, recurse=True): self.Recurse = True @@ -40,6 +52,8 @@ class Mapper(object): return self.map_constant(expr, *args, **kwargs) elif isinstance(expr, list): return self.map_list(expr, *args, **kwargs) + elif is_numpy_array(expr): + return self.map_numpy_array(expr, *args, **kwargs) else: raise ValueError, "encountered invalid foreign object: %s" % repr(expr) @@ -108,6 +122,8 @@ class CombineMapper(RecursiveMapper): map_list = map_sum map_vector = map_sum + def map_numpy_array(self, expr): + return self.combine(expr.flat) @@ -158,17 +174,21 @@ class IdentityMapperBase(object): ((exp, self.rec(coeff, *args, **kwargs)) for exp, coeff in expr.data)) - map_list = map_sum map_vector = map_sum + def map_numpy_array(self, expr): + import numpy + result = numpy.empty(expr.shape, dtype=object) + from pytools import indices_in_shape + for i in indices_in_shape(expr.shape): + result[i] = self.rec(expr[i]) + return result class IdentityMapper(IdentityMapperBase, RecursiveMapper): - def handle_unsupported_expression(self, expr, *args, **kwargs): - return expr + pass class NonrecursiveIdentityMapper(IdentityMapperBase, Mapper): - def handle_unsupported_expression(self, expr, *args, **kwargs): - return expr + pass diff --git a/src/mapper/constant_folder.py b/src/mapper/constant_folder.py index 69acbb2..5da8218 100644 --- a/src/mapper/constant_folder.py +++ b/src/mapper/constant_folder.py @@ -34,6 +34,9 @@ class ConstantFoldingMapperBase(object): return self.fold(expr, Sum, operator.add, Sum) + def handle_unsupported_expression(self, expr): + return expr + class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): diff --git a/src/mapper/evaluator.py b/src/mapper/evaluator.py index 30cbb55..4ff148c 100644 --- a/src/mapper/evaluator.py +++ b/src/mapper/evaluator.py @@ -68,8 +68,15 @@ class EvaluationMapper(RecursiveMapper): return result def map_list(self, expr): - return [self.rec(child) for child in expr.Children] - + return [self.rec(child) for child in expr] + + def map_numpy_array(self, expr): + import numpy + result = numpy.empty(expr.shape, dtype=object) + from pytools import indices_in_shape + for i in indices_in_shape(expr.shape): + result[i] = self.rec(expr[i]) + return result diff --git a/src/mapper/expander.py b/src/mapper/expander.py index 6978bc7..0822075 100644 --- a/src/mapper/expander.py +++ b/src/mapper/expander.py @@ -66,6 +66,8 @@ class ExpandMapper(IdentityMapper): else: return IdentitityMapper.map_power(expr) + def handle_unsupported_expression(self, expr): + return expr -- GitLab