diff --git a/src/mapper/constant_folder.py b/src/mapper/constant_folder.py index 8a658d5101613b6896e3786944d5d7f5c153ff47..69acbb2c6e2e407fc826e54cd28163e165662957 100644 --- a/src/mapper/constant_folder.py +++ b/src/mapper/constant_folder.py @@ -5,50 +5,48 @@ from pymbolic.mapper import IdentityMapper, NonrecursiveIdentityMapper class ConstantFoldingMapperBase(object): def fold(self, expr, klass, op, constructor): - from pymbolic import is_constant + from pymbolic import is_constant, evaluate constants = [] nonconstants = [] - queue = expr.children + queue = list(expr.children) while queue: child = self.rec(queue.pop(0)) if isinstance(child, klass): - queue = child.children + queue + queue = list(child.children) + queue else: if is_constant(child): - constants.append(child) + constants.append(evaluate(child)) else: nonconstants.append(child) if constants: import operator constant = reduce(op, constants) - return constructor([constant]+nonconstants) + return constructor(tuple([constant]+nonconstants)) else: - return constructor(nonconstants) + return constructor(tuple(nonconstants)) def map_sum(self, expr): - from pymbolic import sum from pymbolic.primitives import Sum import operator - return self.fold(expr, Sum, operator.add, sum) + return self.fold(expr, Sum, operator.add, Sum) class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): def map_product(self, expr): - from pymbolic import product from pymbolic.primitives import Product import operator - return self.fold(expr, Product, operator.mul, product) + return self.fold(expr, Product, operator.mul, Product) -class ConstantFoldingMapper(IdentityMapper, ConstantFoldingMapperBase): +class ConstantFoldingMapper(ConstantFoldingMapperBase, IdentityMapper): pass class NonrecursiveConstantFoldingMapper( @@ -56,12 +54,12 @@ class NonrecursiveConstantFoldingMapper( ConstantFoldingMapperBase): pass -class CommutativeConstantFoldingMapper(IdentityMapper, - CommutativeConstantFoldingMapperBase): +class CommutativeConstantFoldingMapper(CommutativeConstantFoldingMapperBase, + IdentityMapper): pass class NonrecursiveCommutativeConstantFoldingMapper( - NonrecursiveIdentityMapper, - CommutativeConstantFoldingMapperBase): + CommutativeConstantFoldingMapperBase, + NonrecursiveIdentityMapper,): pass