diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 309571b3090a0237f67671d8d964e7e6e74b5ae5..72709092d2a9ae2dbb3243b6c615b94f7fd1dda0 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -569,15 +569,14 @@ def unflatten( try: iterable = serialize_container(template_subary) except TypeError: - offset += template_subary.size - if offset > ary.size: + if (offset + template_subary.size) > ary.size: raise ValueError("'template' and 'ary' sizes do not match") if template_subary.dtype != ary.dtype: raise ValueError("'template' dtype does not match 'ary': " f"got {template_subary.dtype}, expected {ary.dtype}") - flat_subary = ary[offset - template_subary.size:offset] + flat_subary = ary[offset:offset + template_subary.size] try: subary = actx.np.reshape(flat_subary, template_subary.shape, order="C") @@ -597,6 +596,7 @@ def unflatten( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") + offset += template_subary.size return subary else: return deserialize_container(template_subary, [