From 112ada652122d37586386cd7e61ec3b6097e85f3 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Wed, 21 Jul 2021 13:47:19 -0500 Subject: [PATCH] add tests for any and all --- test/test_arraycontext.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0f24cbe..668e320 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -120,6 +120,14 @@ class DOFArray: __array_priority__ = 10 + def __bool__(self): + if len(self) == 1 and self.data[0].size == 1: + return bool(self.data[0]) + + raise ValueError( + "The truth value of an array with more than one element is " + "ambiguous. Use actx.np.any(x) or actx.np.all(x)") + def __len__(self): return len(self.data) @@ -260,6 +268,8 @@ def assert_close_to_numpy_in_containers(actx, op, args): ("where", 3, np.float64), ("min", 1, np.float64), ("max", 1, np.float64), + ("any", 1, np.float64), + ("all", 1, np.float64), # float + complex ("sin", 1, np.float64), @@ -519,6 +529,24 @@ def test_reductions_same_as_numpy(actx_factory, op): assert np.allclose(np_red, actx_red) + +@pytest.mark.parametrize("sym_name", ["any", "all"]) +def test_any_all_same_as_numpy(actx_factory, sym_name): + actx = actx_factory() + if not hasattr(actx.np, sym_name): + pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") + + rng = np.random.default_rng() + ary_any = rng.integers(0, 2, 512) + ary_all = np.ones(512) + + assert_close_to_numpy_in_containers(actx, + lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary_any]) + assert_close_to_numpy_in_containers(actx, + lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary_all]) + assert_close_to_numpy_in_containers(actx, + lambda _np, *_args: getattr(_np, sym_name)(*_args), [1 - ary_all]) + # }}} -- GitLab