From 85c8b39c8747c042b7eb8c85fcfc139f182f69a5 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 29 Sep 2021 20:02:27 -0500
Subject: [PATCH] slice forwards when unflattening

---
 arraycontext/container/traversal.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 309571b..7270909 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -569,15 +569,14 @@ def unflatten(
         try:
             iterable = serialize_container(template_subary)
         except TypeError:
-            offset += template_subary.size
-            if offset > ary.size:
+            if (offset + template_subary.size) > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match")
 
             if template_subary.dtype != ary.dtype:
                 raise ValueError("'template' dtype does not match 'ary': "
                         f"got {template_subary.dtype}, expected {ary.dtype}")
 
-            flat_subary = ary[offset - template_subary.size:offset]
+            flat_subary = ary[offset:offset + template_subary.size]
             try:
                 subary = actx.np.reshape(flat_subary,
                         template_subary.shape, order="C")
@@ -597,6 +596,7 @@ def unflatten(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
 
+            offset += template_subary.size
             return subary
         else:
             return deserialize_container(template_subary, [
-- 
GitLab