From 7b03174b6ede7e540c663676de1961558a2b2c48 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Thu, 30 Sep 2021 19:58:35 -0500 Subject: [PATCH] update test skip condition --- arraycontext/container/traversal.py | 3 +-- test/test_arraycontext.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 421e5a0..8cbd470 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -590,8 +590,7 @@ def unflatten( "array context.") from exc if hasattr(template_subary, "strides"): - if template_subary.size != 0 \ - and template_subary.strides != subary.strides: + if template_subary.strides != subary.strides: raise ValueError( f"strides do not match template: got {subary.strides}, " f"expected {template_subary.strides}") diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index f2047fe..c2be943 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -900,9 +900,11 @@ def test_container_norm(actx_factory, ord): [(64, 7), (154, 12)] ]) def test_flatten_array_container(actx_factory, shapes): + if np.prod(shapes) == 0: + # https://github.com/inducer/compyte/pull/36 + pytest.xfail("strides do not match in subary") + actx = actx_factory() - if np.prod(shapes) == 0 and isinstance(actx, PytatoPyOpenCLArrayContext): - pytest.skip("operation not supported on PytatoPyOpenCLArrayContext") from arraycontext import flatten, unflatten arys = _get_test_containers(actx, shapes=shapes) @@ -912,9 +914,6 @@ def test_flatten_array_container(actx_factory, shapes): assert flat.ndim == 1 ary_roundtrip = unflatten(ary, flat, actx) - assert actx.to_numpy( - actx.np.linalg.norm(ary - ary_roundtrip) - ) < 1.0e-15 from arraycontext import rec_multimap_reduce_array_container assert rec_multimap_reduce_array_container( @@ -922,6 +921,10 @@ def test_flatten_array_container(actx_factory, shapes): lambda x, y: x.shape == y.shape, ary, ary_roundtrip) + assert actx.to_numpy( + actx.np.linalg.norm(ary - ary_roundtrip) + ) < 1.0e-15 + def test_flatten_array_container_failure(actx_factory): actx = actx_factory() -- GitLab