diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 421e5a0868d5b981d2809f8f2a20fec1129c9c2b..8cbd470caf1f85cf8b37da93a63473a92596561c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -590,8 +590,7 @@ def unflatten( "array context.") from exc if hasattr(template_subary, "strides"): - if template_subary.size != 0 \ - and template_subary.strides != subary.strides: + if template_subary.strides != subary.strides: raise ValueError( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index f2047fe6a43aa1d6250500f1f98c3c0f2c356e84..c2be94375a090abd12adb8ad0d1224581761fe34 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -900,9 +900,11 @@ def test_container_norm(actx_factory, ord): [(64, 7), (154, 12)] ]) def test_flatten_array_container(actx_factory, shapes): + if np.prod(shapes) == 0: + # https://github.com/inducer/compyte/pull/36 + pytest.xfail("strides do not match in subary") + actx = actx_factory() - if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext): - pytest.skip("operation not supported on PytatoPyOpenCLArrayContext") from arraycontext import flatten, unflatten arys = _get_test_containers(actx, shapes=shapes) @@ -912,9 +914,6 @@ def test_flatten_array_container(actx_factory, shapes): assert flat.ndim == 1 ary_roundtrip = unflatten(ary, flat, actx) - assert actx.to_numpy( - actx.np.linalg.norm(ary - ary_roundtrip) - ) < 1.0e-15 from arraycontext import rec_multimap_reduce_array_container assert rec_multimap_reduce_array_container( @@ -922,6 +921,10 @@ def test_flatten_array_container(actx_factory, shapes): lambda x, y: x.shape == y.shape, ary, ary_roundtrip) + assert actx.to_numpy( + actx.np.linalg.norm(ary - ary_roundtrip) + ) < 1.0e-15 + def test_flatten_array_container_failure(actx_factory): actx = actx_factory()