From 33884ae50cfb13c01e434e6d92483ca9e386cdf2 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sun, 7 Nov 2021 09:54:04 -0600
Subject: [PATCH] expand tests to actually check scalars

---
 arraycontext/fake_numpy.py |  3 ++-
 test/test_arraycontext.py  | 20 ++++++++++++++++++--
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index 251a5a2..c75295c 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -170,7 +170,8 @@ class BaseFakeNumpyNamespace:
         c_name = self._numpy_to_c_arc_functions.get(name, name)
 
         # limit which functions we try to hand off to loopy
-        if name in self._numpy_math_functions:
+        if (name in self._numpy_math_functions
+                or name in self._c_to_numpy_arc_functions):
             return multimapped_over_array_containers(loopy_implemented_elwise_func)
         else:
             raise AttributeError(name)
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 39e49b2..75a91cb 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -308,8 +308,24 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
     ndofs = 512
     args = [randn(ndofs, dtype) for i in range(n_args)]
 
-    assert_close_to_numpy_in_containers(
-            actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args)
+    c_to_numpy_arc_functions = {
+            "atan": "arctan",
+            "atan2": "arctan2",
+            }
+
+    def evaluate(_np, *_args):
+        func = getattr(_np, sym_name,
+                getattr(_np, c_to_numpy_arc_functions.get(sym_name, sym_name)))
+
+        return func(*_args)
+
+    assert_close_to_numpy_in_containers(actx, evaluate, args)
+
+    if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
+        pytest.skip(f"'{sym_name}' not supported on scalars")
+
+    args = [randn(0, dtype)[()] for i in range(n_args)]
+    assert_close_to_numpy(actx, evaluate, args)
 
 
 @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [
-- 
GitLab