From bc7139f85d8933ac81e4bb326d514f49fa4cc26a Mon Sep 17 00:00:00 2001 From: nkoskelo <129830924+nkoskelo@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:58:36 +0000 Subject: [PATCH] Add an implementation of np.vdot to PytatoPyOpenCLArrayContext (#299) * Add an implementation of vdot to the PytatoPyOpenCLArrayContext np namespace. * Remove the tests that are just skipped for scalars. * Respond to comments. * Ruff version needed to be updated locally. --- arraycontext/impl/pytato/fake_numpy.py | 3 +++ test/test_arraycontext.py | 9 ++++----- test/test_utils.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d707285..21dc71e 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -239,4 +239,7 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def absolute(self, a): return self.abs(a) + def vdot(self, a: Array, b: Array): + + return rec_multimap_array_container(pt.vdot, a, b) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 14d24dd..ad2cbb1 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -271,11 +271,10 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): assert_close_to_numpy_in_containers(actx, evaluate, args) - if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: - pytest.skip(f"'{sym_name}' not supported on scalars") - - args = [randn(0, dtype)[()] for i in range(n_args)] - assert_close_to_numpy(actx, evaluate, args) + if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: + # Scalar arguments are supported. + args = [randn(0, dtype)[()] for i in range(n_args)] + assert_close_to_numpy(actx, evaluate, args) @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ diff --git a/test/test_utils.py b/test/test_utils.py index 807d652..3b74a42 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -27,7 +27,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import logging -from typing import Optional, cast +from typing import cast import numpy as np import pytest @@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None: class ArrayContainerWithOptional: x: np.ndarray # Deliberately left as Optional to test compatibility. - y: Optional[np.ndarray] # noqa: UP007 + y: np.ndarray | None with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`) -- GitLab