diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index c3bf32fe1ec6bc941cf52f26030d36aef952eaf4..dc608807828e113f4589d2ac7edde0486266e1bd 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -655,25 +655,26 @@ def unflatten_from_numpy( # NOTE: https://github.com/python/mypy/issues/7057 offset = 0 - def _unflatten_from_numpy(subary: ArrayOrContainerT) -> ArrayOrContainerT: + def _unflatten_from_numpy( + template_subary: ArrayOrContainerT) -> ArrayOrContainerT: nonlocal offset try: - iterable = serialize_container(subary) + iterable = serialize_container(template_subary) except TypeError: # NOTE: the max is needed to handle device scalars with size == 0 - offset += max(1, subary.size) + offset += max(1, template_subary.size) if offset > ary.size: raise ValueError("'template' and 'ary' sizes do not match") # FIXME: subary can be F-contiguous and ary will always be C-contiguous return actx.from_numpy( - ary[offset - subary.size:offset] - .astype(subary.dtype, copy=False) - .reshape(subary.shape) + ary[offset - template_subary.size:offset] + .astype(template_subary.dtype, copy=False) + .reshape(template_subary.shape) ) else: - return deserialize_container(subary, [ + return deserialize_container(template_subary, [ (key, _unflatten_from_numpy(isubary)) for key, isubary in iterable ])