From 33884ae50cfb13c01e434e6d92483ca9e386cdf2 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sun, 7 Nov 2021 09:54:04 -0600 Subject: [PATCH] expand tests to actually check scalars --- arraycontext/fake_numpy.py | 3 ++- test/test_arraycontext.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 251a5a2..c75295c 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -170,7 +170,8 @@ class BaseFakeNumpyNamespace: c_name = self._numpy_to_c_arc_functions.get(name, name) # limit which functions we try to hand off to loopy - if name in self._numpy_math_functions: + if (name in self._numpy_math_functions + or name in self._c_to_numpy_arc_functions): return multimapped_over_array_containers(loopy_implemented_elwise_func) else: raise AttributeError(name) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 39e49b2..75a91cb 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -308,8 +308,24 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): ndofs = 512 args = [randn(ndofs, dtype) for i in range(n_args)] - assert_close_to_numpy_in_containers( - actx, lambda _np, *_args: getattr(_np, sym_name)(*_args), args) + c_to_numpy_arc_functions = { + "atan": "arctan", + "atan2": "arctan2", + } + + def evaluate(_np, *_args): + func = getattr(_np, sym_name, + getattr(_np, c_to_numpy_arc_functions.get(sym_name, sym_name))) + + return func(*_args) + + assert_close_to_numpy_in_containers(actx, evaluate, args) + + if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: + pytest.skip(f"'{sym_name}' not supported on scalars") + + args = [randn(0, dtype)[()] for i in range(n_args)] + assert_close_to_numpy(actx, evaluate, args) @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ -- GitLab