From 5dc9da2c3f119455a5e796996feb27117d078d0b Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 29 Sep 2021 20:38:26 -0500
Subject: [PATCH] add some more tests for unflatten

---
 test/test_arraycontext.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 0ed8625..f2047fe 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -922,6 +922,26 @@ def test_flatten_array_container(actx_factory, shapes):
                 lambda x, y: x.shape == y.shape,
                 ary, ary_roundtrip)
 
+
+def test_flatten_array_container_failure(actx_factory):
+    actx = actx_factory()
+
+    from arraycontext import flatten, unflatten
+    ary = _get_test_containers(actx, shapes=512)[0]
+    flat_ary = flatten(ary, actx)
+
+    with pytest.raises(TypeError):
+        # cannot unflatten from a numpy array
+        unflatten(ary, actx.to_numpy(flat_ary), actx)
+
+    with pytest.raises(ValueError):
+        # cannot unflatten non-flat arrays
+        unflatten(ary, flat_ary.reshape(2, -1), actx)
+
+    with pytest.raises(ValueError):
+        # cannot unflatten partially
+        unflatten(ary, flat_ary[:-1], actx)
+
 # }}}
 
 
-- 
GitLab