From 38beeed5cdce9426f1c0d09f13d59d11e6219abf Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 21 Oct 2012 20:52:21 -0400
Subject: [PATCH] Make sure flattened_sum() doesn't do constant folding by
 accident.

---
 pymbolic/primitives.py | 28 ++++++++++++++++++++--------
 test/test_pymbolic.py  | 11 +++++++++++
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index 2d67adb..6af60d8 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -790,15 +790,27 @@ def subscript(expression, index):
 
 
 def flattened_sum(components):
-    it = components.__iter__()
-    try:
-        result = it.next()
-    except StopIteration:
-        return 0
+    # flatten any potential sub-sums
+    queue = list(components)
+    done = []
+
+    while queue:
+        item = queue.pop(0)
+
+        if is_zero(item):
+            continue
 
-    for i in it:
-        result = result + i
-    return result
+        if isinstance(item, Sum):
+            queue += item.children
+        else:
+            done.append(item)
+
+    if len(done) == 0:
+        return 0
+    elif len(done) == 1:
+        return done[0]
+    else:
+        return Sum(tuple(done))
 
 
 
diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index 90f6862..8fc5a08 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -163,6 +163,17 @@ def test_parser():
 
 
 
+def test_structure_preservation():
+    x = prim.Sum((5, 7))
+    from pymbolic.mapper import IdentityMapper
+    x2 = IdentityMapper()(x)
+    assert x == x2
+
+
+
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab