From 087351c62eb4ccd68f90990ed8de8317a68c6d29 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 29 Sep 2021 20:01:31 -0500
Subject: [PATCH] add tests for flatten edge cases

---
 arraycontext/container/traversal.py |  6 +++---
 test/test_arraycontext.py           | 21 +++++++++++++++++----
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index db3633d..309571b 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -569,8 +569,7 @@ def unflatten(
         try:
             iterable = serialize_container(template_subary)
         except TypeError:
-            # NOTE: the max is needed to handle device scalars with size == 0
-            offset += max(1, template_subary.size)
+            offset += template_subary.size
             if offset > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match")
 
@@ -592,7 +591,8 @@ def unflatten(
                         "array context.") from exc
 
             if hasattr(template_subary, "strides"):
-                if template_subary.strides != subary.strides:
+                if template_subary.size != 0 \
+                        and template_subary.strides != subary.strides:
                     raise ValueError(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 9d27eaa..0ed8625 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -892,13 +892,20 @@ def test_container_norm(actx_factory, ord):
 
 # {{{ test flatten and unflatten
 
-def test_flatten_array_container(actx_factory):
+@pytest.mark.parametrize("shapes", [
+    0,                          # tests device scalars when flattening
+    512,
+    [(128, 67)],
+    [(127, 67), (18, 0)],       # tests 0-sized arrays
+    [(64, 7), (154, 12)]
+    ])
+def test_flatten_array_container(actx_factory, shapes):
     actx = actx_factory()
+    if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext):
+        pytest.skip("operation not supported on PytatoPyOpenCLArrayContext")
 
     from arraycontext import flatten, unflatten
-    arys = _get_test_containers(actx, shapes=512) \
-            + _get_test_containers(actx, shapes=(128, 67)) \
-            + _get_test_containers(actx, shapes=[(64, 7), (154, 12)])
+    arys = _get_test_containers(actx, shapes=shapes)
 
     for ary in arys:
         flat = flatten(ary, actx)
@@ -909,6 +916,12 @@ def test_flatten_array_container(actx_factory):
                 actx.np.linalg.norm(ary - ary_roundtrip)
                 ) < 1.0e-15
 
+        from arraycontext import rec_multimap_reduce_array_container
+        assert rec_multimap_reduce_array_container(
+                np.prod,
+                lambda x, y: x.shape == y.shape,
+                ary, ary_roundtrip)
+
 # }}}
 
 
-- 
GitLab