diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py
index 2d67adbde42b66df19322c1ea8ec6c4407354726..6af60d8bc713e4e848e4c008c4a05b7eccc289a9 100644
--- a/pymbolic/primitives.py
+++ b/pymbolic/primitives.py
@@ -790,15 +790,27 @@ def subscript(expression, index):
 
 
 def flattened_sum(components):
-    it = components.__iter__()
-    try:
-        result = it.next()
-    except StopIteration:
-        return 0
+    # flatten any potential sub-sums
+    queue = list(components)
+    done = []
+
+    while queue:
+        item = queue.pop(0)
+
+        if is_zero(item):
+            continue
 
-    for i in it:
-        result = result + i
-    return result
+        if isinstance(item, Sum):
+            queue += item.children
+        else:
+            done.append(item)
+
+    if len(done) == 0:
+        return 0
+    elif len(done) == 1:
+        return done[0]
+    else:
+        return Sum(tuple(done))
 
 
 
diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py
index 90f6862287f02b683b4d960d6c149f9217d62af4..8fc5a08b17d3dba9f39cf63c59ad5eafa03324cd 100644
--- a/test/test_pymbolic.py
+++ b/test/test_pymbolic.py
@@ -163,6 +163,17 @@ def test_parser():
 
 
 
+def test_structure_preservation():
+    x = prim.Sum((5, 7))
+    from pymbolic.mapper import IdentityMapper
+    x2 = IdentityMapper()(x)
+    assert x == x2
+
+
+
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: