diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 5a4dcf745282a0a1e07304a2ddcbebc5d9dac8f6..940f5ea7fd7b9c526f0054472600693329619456 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -754,7 +754,7 @@ def flatten( def unflatten( - template: ArrayOrContainerT, ary: Any, + template: ArrayOrContainerT, ary: Array, actx: ArrayContext, *, strict: bool = True) -> ArrayOrContainerT: """Unflatten an array *ary* produced by :func:`flatten` back into an @@ -822,17 +822,17 @@ def unflatten( # {{{ check strides - if strict and hasattr(template_subary, "strides"): + if strict and hasattr(template_subary_c, "strides"): # Checking strides for 0 sized arrays is ill-defined # since they cannot be indexed if ( # Mypy has a point: nobody promised a .strides attribute. - template_subary_c.strides != subary.strides # type: ignore[attr-defined] # noqa: E501 + template_subary_c.strides != subary.strides and template_subary_c.size != 0 ): raise ValueError( # Mypy has a point: nobody promised a .strides attribute. - f"strides do not match template: got {subary.strides}, " # type: ignore[attr-defined] # noqa: E501 + f"strides do not match template: got {subary.strides}, " f"expected {template_subary_c.strides}") # }}} @@ -849,7 +849,7 @@ def unflatten( f"array context: got '{type(ary).__name__}', expected one of " f"{actx.array_types}") - if ary.ndim != 1: + if len(ary.shape) != 1: raise ValueError( "only one dimensional arrays can be unflattened: " f"'ary' has shape {ary.shape}")