From 7b03174b6ede7e540c663676de1961558a2b2c48 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Thu, 30 Sep 2021 19:58:35 -0500
Subject: [PATCH] update test skip condition

---
 arraycontext/container/traversal.py |  3 +--
 test/test_arraycontext.py           | 13 ++++++++-----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 421e5a0..8cbd470 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -590,8 +590,7 @@ def unflatten(
                         "array context.") from exc
 
             if hasattr(template_subary, "strides"):
-                if template_subary.size != 0 \
-                        and template_subary.strides != subary.strides:
+                if template_subary.strides != subary.strides:
                     raise ValueError(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index f2047fe..c2be943 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -900,9 +900,11 @@ def test_container_norm(actx_factory, ord):
     [(64, 7), (154, 12)]
     ])
 def test_flatten_array_container(actx_factory, shapes):
+    if np.prod(shapes) == 0:
+        # https://github.com/inducer/compyte/pull/36
+        pytest.xfail("strides do not match in subary")
+
     actx = actx_factory()
-    if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext):
-        pytest.skip("operation not supported on PytatoPyOpenCLArrayContext")
 
     from arraycontext import flatten, unflatten
     arys = _get_test_containers(actx, shapes=shapes)
@@ -912,9 +914,6 @@ def test_flatten_array_container(actx_factory, shapes):
         assert flat.ndim == 1
 
         ary_roundtrip = unflatten(ary, flat, actx)
-        assert actx.to_numpy(
-                actx.np.linalg.norm(ary - ary_roundtrip)
-                ) < 1.0e-15
 
         from arraycontext import rec_multimap_reduce_array_container
         assert rec_multimap_reduce_array_container(
@@ -922,6 +921,10 @@ def test_flatten_array_container(actx_factory, shapes):
                 lambda x, y: x.shape == y.shape,
                 ary, ary_roundtrip)
 
+        assert actx.to_numpy(
+                actx.np.linalg.norm(ary - ary_roundtrip)
+                ) < 1.0e-15
+
 
 def test_flatten_array_container_failure(actx_factory):
     actx = actx_factory()
-- 
GitLab