diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 2d67adbde42b66df19322c1ea8ec6c4407354726..6af60d8bc713e4e848e4c008c4a05b7eccc289a9 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -790,15 +790,27 @@ def subscript(expression, index): def flattened_sum(components): - it = components.__iter__() - try: - result = it.next() - except StopIteration: - return 0 + # flatten any potential sub-sums + queue = list(components) + done = [] + + while queue: + item = queue.pop(0) + + if is_zero(item): + continue - for i in it: - result = result + i - return result + if isinstance(item, Sum): + queue += item.children + else: + done.append(item) + + if len(done) == 0: + return 0 + elif len(done) == 1: + return done[0] + else: + return Sum(tuple(done)) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 90f6862287f02b683b4d960d6c149f9217d62af4..8fc5a08b17d3dba9f39cf63c59ad5eafa03324cd 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -163,6 +163,17 @@ def test_parser(): +def test_structure_preservation(): + x = prim.Sum((5, 7)) + from pymbolic.mapper import IdentityMapper + x2 = IdentityMapper()(x) + assert x == x2 + + + + + + if __name__ == "__main__": import sys if len(sys.argv) > 1: