From 112ada652122d37586386cd7e61ec3b6097e85f3 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 21 Jul 2021 13:47:19 -0500
Subject: [PATCH] add tests for any and all

---
 test/test_arraycontext.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 0f24cbe..668e320 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -120,6 +120,14 @@ class DOFArray:
 
     __array_priority__ = 10
 
+    def __bool__(self):
+        if len(self) == 1 and self.data[0].size == 1:
+            return bool(self.data[0])
+
+        raise ValueError(
+                "The truth value of an array with more than one element is "
+                "ambiguous. Use actx.np.any(x) or actx.np.all(x)")
+
     def __len__(self):
         return len(self.data)
 
@@ -260,6 +268,8 @@ def assert_close_to_numpy_in_containers(actx, op, args):
             ("where", 3, np.float64),
             ("min", 1, np.float64),
             ("max", 1, np.float64),
+            ("any", 1, np.float64),
+            ("all", 1, np.float64),
 
             # float + complex
             ("sin", 1, np.float64),
@@ -519,6 +529,24 @@ def test_reductions_same_as_numpy(actx_factory, op):
 
     assert np.allclose(np_red, actx_red)
 
+
+@pytest.mark.parametrize("sym_name", ["any", "all"])
+def test_any_all_same_as_numpy(actx_factory, sym_name):
+    actx = actx_factory()
+    if not hasattr(actx.np, sym_name):
+        pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'")
+
+    rng = np.random.default_rng()
+    ary_any = rng.integers(0, 2, 512)
+    ary_all = np.ones(512)
+
+    assert_close_to_numpy_in_containers(actx,
+                lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary_any])
+    assert_close_to_numpy_in_containers(actx,
+                lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary_all])
+    assert_close_to_numpy_in_containers(actx,
+                lambda _np, *_args: getattr(_np, sym_name)(*_args), [1 - ary_all])
+
 # }}}
 
 
-- 
GitLab