From bc7139f85d8933ac81e4bb326d514f49fa4cc26a Mon Sep 17 00:00:00 2001
From: nkoskelo <129830924+nkoskelo@users.noreply.github.com>
Date: Thu, 9 Jan 2025 21:58:36 +0000
Subject: [PATCH] Add an implementation of np.vdot to
 PytatoPyOpenCLArrayContext (#299)

* Add an implementation of vdot to the PytatoPyOpenCLArrayContext np namespace.

* Remove the tests that are just skipped for scalars.

* Respond to comments.

* Ruff version needed to be updated locally.
---
 arraycontext/impl/pytato/fake_numpy.py | 3 +++
 test/test_arraycontext.py              | 9 ++++-----
 test/test_utils.py                     | 4 ++--
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index d707285..21dc71e 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -239,4 +239,7 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
     def absolute(self, a):
         return self.abs(a)
 
+    def vdot(self, a: Array, b: Array):
+
+        return rec_multimap_array_container(pt.vdot, a, b)
     # }}}
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 14d24dd..ad2cbb1 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -271,11 +271,10 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
 
     assert_close_to_numpy_in_containers(actx, evaluate, args)
 
-    if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
-        pytest.skip(f"'{sym_name}' not supported on scalars")
-
-    args = [randn(0, dtype)[()] for i in range(n_args)]
-    assert_close_to_numpy(actx, evaluate, args)
+    if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
+        # Scalar arguments are supported.
+        args = [randn(0, dtype)[()] for i in range(n_args)]
+        assert_close_to_numpy(actx, evaluate, args)
 
 
 @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [
diff --git a/test/test_utils.py b/test/test_utils.py
index 807d652..3b74a42 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -27,7 +27,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 import logging
-from typing import Optional, cast
+from typing import cast
 
 import numpy as np
 import pytest
@@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None:
     class ArrayContainerWithOptional:
         x: np.ndarray
         # Deliberately left as Optional to test compatibility.
-        y: Optional[np.ndarray]  # noqa: UP007
+        y: np.ndarray | None
 
     with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
         # NOTE: cannot have wrapped annotations (here by `Optional`)
-- 
GitLab