diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d7072855a1a3b1f0bc463bf0b7f63b0bedf6180c..21dc71edb2288e6a5267d0608b25fbdc82e48010 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -239,4 +239,7 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): def absolute(self, a): return self.abs(a) + def vdot(self, a: Array, b: Array): + + return rec_multimap_array_container(pt.vdot, a, b) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 14d24dd4ef7a6fe85766369c62a08e255cf476d9..ad2cbb1031a6734750debe39ba369997010cbc14 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -271,11 +271,10 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): assert_close_to_numpy_in_containers(actx, evaluate, args) - if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: - pytest.skip(f"'{sym_name}' not supported on scalars") - - args = [randn(0, dtype)[()] for i in range(n_args)] - assert_close_to_numpy(actx, evaluate, args) + if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: + # Scalar arguments are supported. + args = [randn(0, dtype)[()] for i in range(n_args)] + assert_close_to_numpy(actx, evaluate, args) @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ diff --git a/test/test_utils.py b/test/test_utils.py index 807d652d36154f7e0d9b8ce1b33e3de61986969f..3b74a42feadcfbbef6b89c7e81946144cc9b30a2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -27,7 +27,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import logging -from typing import Optional, cast +from typing import cast import numpy as np import pytest @@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None: class ArrayContainerWithOptional: x: np.ndarray # Deliberately left as Optional to test compatibility. - y: Optional[np.ndarray] # noqa: UP007 + y: np.ndarray | None with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`)