From 85c8b39c8747c042b7eb8c85fcfc139f182f69a5 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Wed, 29 Sep 2021 20:02:27 -0500 Subject: [PATCH] slice forwards when unflattening --- arraycontext/container/traversal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 309571b..7270909 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -569,15 +569,14 @@ def unflatten( try: iterable = serialize_container(template_subary) except TypeError: - offset += template_subary.size - if offset > ary.size: + if (offset + template_subary.size) > ary.size: raise ValueError("'template' and 'ary' sizes do not match") if template_subary.dtype != ary.dtype: raise ValueError("'template' dtype does not match 'ary': " f"got {template_subary.dtype}, expected {ary.dtype}") - flat_subary = ary[offset - template_subary.size:offset] + flat_subary = ary[offset:offset + template_subary.size] try: subary = actx.np.reshape(flat_subary, template_subary.shape, order="C") @@ -597,6 +596,7 @@ def unflatten( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") + offset += template_subary.size return subary else: return deserialize_container(template_subary, [ -- GitLab