diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index db3633d17d994fb861a03b5f9c54574aeb966656..309571b3090a0237f67671d8d964e7e6e74b5ae5 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -569,8 +569,7 @@ def unflatten(
         try:
             iterable = serialize_container(template_subary)
         except TypeError:
-            # NOTE: the max is needed to handle device scalars with size == 0
-            offset += max(1, template_subary.size)
+            offset += template_subary.size
             if offset > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match")
 
@@ -592,7 +591,8 @@ def unflatten(
                         "array context.") from exc
 
             if hasattr(template_subary, "strides"):
-                if template_subary.strides != subary.strides:
+                if template_subary.size != 0 \
+                        and template_subary.strides != subary.strides:
                     raise ValueError(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 9d27eaae5f4bf3b86a8312c60789f2373ddde716..0ed86251980db4ac4b2fa7d79178ab02fbf21d7b 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -892,13 +892,20 @@ def test_container_norm(actx_factory, ord):
 
 # {{{ test flatten and unflatten
 
-def test_flatten_array_container(actx_factory):
+@pytest.mark.parametrize("shapes", [
+    0,                          # tests device scalars when flattening
+    512,
+    [(128, 67)],
+    [(127, 67), (18, 0)],       # tests 0-sized arrays
+    [(64, 7), (154, 12)]
+    ])
+def test_flatten_array_container(actx_factory, shapes):
     actx = actx_factory()
+    if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext):
+        pytest.skip("operation not supported on PytatoPyOpenCLArrayContext")
 
     from arraycontext import flatten, unflatten
-    arys = _get_test_containers(actx, shapes=512) \
-            + _get_test_containers(actx, shapes=(128, 67)) \
-            + _get_test_containers(actx, shapes=[(64, 7), (154, 12)])
+    arys = _get_test_containers(actx, shapes=shapes)
 
     for ary in arys:
         flat = flatten(ary, actx)
@@ -909,6 +916,12 @@ def test_flatten_array_container(actx_factory):
                 actx.np.linalg.norm(ary - ary_roundtrip)
                 ) < 1.0e-15
 
+        from arraycontext import rec_multimap_reduce_array_container
+        assert rec_multimap_reduce_array_container(
+                np.prod,
+                lambda x, y: x.shape == y.shape,
+                ary, ary_roundtrip)
+
 # }}}