From 38beeed5cdce9426f1c0d09f13d59d11e6219abf Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 21 Oct 2012 20:52:21 -0400 Subject: [PATCH] Make sure flattened_sum() doesn't do constant folding by accident. --- pymbolic/primitives.py | 28 ++++++++++++++++++++-------- test/test_pymbolic.py | 11 +++++++++++ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 2d67adb..6af60d8 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -790,15 +790,27 @@ def subscript(expression, index): def flattened_sum(components): - it = components.__iter__() - try: - result = it.next() - except StopIteration: - return 0 + # flatten any potential sub-sums + queue = list(components) + done = [] + + while queue: + item = queue.pop(0) + + if is_zero(item): + continue - for i in it: - result = result + i - return result + if isinstance(item, Sum): + queue += item.children + else: + done.append(item) + + if len(done) == 0: + return 0 + elif len(done) == 1: + return done[0] + else: + return Sum(tuple(done)) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 90f6862..8fc5a08 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -163,6 +163,17 @@ def test_parser(): +def test_structure_preservation(): + x = prim.Sum((5, 7)) + from pymbolic.mapper import IdentityMapper + x2 = IdentityMapper()(x) + assert x == x2 + + + + + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab