From 5d7f533b14fdbad56016c1e6e401b938a6a45b1d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 28 Feb 2008 14:47:49 -0500 Subject: [PATCH] Fix constant folder (inheritance order broken, etc.). --- src/mapper/constant_folder.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/mapper/constant_folder.py b/src/mapper/constant_folder.py index 8a658d5..69acbb2 100644 --- a/src/mapper/constant_folder.py +++ b/src/mapper/constant_folder.py @@ -5,50 +5,48 @@ from pymbolic.mapper import IdentityMapper, NonrecursiveIdentityMapper class ConstantFoldingMapperBase(object): def fold(self, expr, klass, op, constructor): - from pymbolic import is_constant + from pymbolic import is_constant, evaluate constants = [] nonconstants = [] - queue = expr.children + queue = list(expr.children) while queue: child = self.rec(queue.pop(0)) if isinstance(child, klass): - queue = child.children + queue + queue = list(child.children) + queue else: if is_constant(child): - constants.append(child) + constants.append(evaluate(child)) else: nonconstants.append(child) if constants: import operator constant = reduce(op, constants) - return constructor([constant]+nonconstants) + return constructor(tuple([constant]+nonconstants)) else: - return constructor(nonconstants) + return constructor(tuple(nonconstants)) def map_sum(self, expr): - from pymbolic import sum from pymbolic.primitives import Sum import operator - return self.fold(expr, Sum, operator.add, sum) + return self.fold(expr, Sum, operator.add, Sum) class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): def map_product(self, expr): - from pymbolic import product from pymbolic.primitives import Product import operator - return self.fold(expr, Product, operator.mul, product) + return self.fold(expr, Product, operator.mul, Product) -class ConstantFoldingMapper(IdentityMapper, ConstantFoldingMapperBase): +class ConstantFoldingMapper(ConstantFoldingMapperBase, IdentityMapper): pass class NonrecursiveConstantFoldingMapper( @@ -56,12 +54,12 @@ class NonrecursiveConstantFoldingMapper( ConstantFoldingMapperBase): pass -class CommutativeConstantFoldingMapper(IdentityMapper, - CommutativeConstantFoldingMapperBase): +class CommutativeConstantFoldingMapper(CommutativeConstantFoldingMapperBase, + IdentityMapper): pass class NonrecursiveCommutativeConstantFoldingMapper( - NonrecursiveIdentityMapper, - CommutativeConstantFoldingMapperBase): + CommutativeConstantFoldingMapperBase, + NonrecursiveIdentityMapper,): pass -- GitLab