diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 08d9ba9fccf555baa0f7e155caa00a88d65da3a2..e7bf7312b504e497ca529a803719e598ac922030 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -582,6 +582,12 @@ def unflatten( "functionality that is not implemented in the array " f"context '{type(actx).__name__}'") from exc + if hasattr(template_subary, "strides"): + if template_subary.strides != subary.strides: + raise ValueError( + f"strides do not match template: got {subary.strides}, " + f"expected {template_subary.strides}") + return subary else: return deserialize_container(template_subary, [