diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0ed86251980db4ac4b2fa7d79178ab02fbf21d7b..f2047fe6a43aa1d6250500f1f98c3c0f2c356e84 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -922,6 +922,26 @@ def test_flatten_array_container(actx_factory, shapes): lambda x, y: x.shape == y.shape, ary, ary_roundtrip) + +def test_flatten_array_container_failure(actx_factory): + actx = actx_factory() + + from arraycontext import flatten, unflatten + ary = _get_test_containers(actx, shapes=512)[0] + flat_ary = flatten(ary, actx) + + with pytest.raises(TypeError): + # cannot unflatten from a numpy array + unflatten(ary, actx.to_numpy(flat_ary), actx) + + with pytest.raises(ValueError): + # cannot unflatten non-flat arrays + unflatten(ary, flat_ary.reshape(2, -1), actx) + + with pytest.raises(ValueError): + # cannot unflatten partially + unflatten(ary, flat_ary[:-1], actx) + # }}}