diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 309571b3090a0237f67671d8d964e7e6e74b5ae5..72709092d2a9ae2dbb3243b6c615b94f7fd1dda0 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -569,15 +569,14 @@ def unflatten(
         try:
             iterable = serialize_container(template_subary)
         except TypeError:
-            offset += template_subary.size
-            if offset > ary.size:
+            if (offset + template_subary.size) > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match")
 
             if template_subary.dtype != ary.dtype:
                 raise ValueError("'template' dtype does not match 'ary': "
                         f"got {template_subary.dtype}, expected {ary.dtype}")
 
-            flat_subary = ary[offset - template_subary.size:offset]
+            flat_subary = ary[offset:offset + template_subary.size]
             try:
                 subary = actx.np.reshape(flat_subary,
                         template_subary.shape, order="C")
@@ -597,6 +596,7 @@ def unflatten(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
 
+            offset += template_subary.size
             return subary
         else:
             return deserialize_container(template_subary, [