diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 421e5a0868d5b981d2809f8f2a20fec1129c9c2b..8cbd470caf1f85cf8b37da93a63473a92596561c 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -590,8 +590,7 @@ def unflatten(
                         "array context.") from exc
 
             if hasattr(template_subary, "strides"):
-                if template_subary.size != 0 \
-                        and template_subary.strides != subary.strides:
+                if template_subary.strides != subary.strides:
                     raise ValueError(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index f2047fe6a43aa1d6250500f1f98c3c0f2c356e84..c2be94375a090abd12adb8ad0d1224581761fe34 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -900,9 +900,11 @@ def test_container_norm(actx_factory, ord):
     [(64, 7), (154, 12)]
     ])
 def test_flatten_array_container(actx_factory, shapes):
+    if np.prod(shapes) == 0:
+        # https://github.com/inducer/compyte/pull/36
+        pytest.xfail("strides do not match in subary")
+
     actx = actx_factory()
-    if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext):
-        pytest.skip("operation not supported on PytatoPyOpenCLArrayContext")
 
     from arraycontext import flatten, unflatten
     arys = _get_test_containers(actx, shapes=shapes)
@@ -912,9 +914,6 @@ def test_flatten_array_container(actx_factory, shapes):
         assert flat.ndim == 1
 
         ary_roundtrip = unflatten(ary, flat, actx)
-        assert actx.to_numpy(
-                actx.np.linalg.norm(ary - ary_roundtrip)
-                ) < 1.0e-15
 
         from arraycontext import rec_multimap_reduce_array_container
         assert rec_multimap_reduce_array_container(
@@ -922,6 +921,10 @@ def test_flatten_array_container(actx_factory, shapes):
                 lambda x, y: x.shape == y.shape,
                 ary, ary_roundtrip)
 
+        assert actx.to_numpy(
+                actx.np.linalg.norm(ary - ary_roundtrip)
+                ) < 1.0e-15
+
 
 def test_flatten_array_container_failure(actx_factory):
     actx = actx_factory()