From 2efd3dae4f5ac4dcd938d9d6840dc55d44612baf Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 21 Jul 2021 18:47:38 -0500 Subject: [PATCH] implements pt.(v?)dot --- pytato/__init__.py | 4 ++++ pytato/array.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/pytato/__init__.py b/pytato/__init__.py index 114eedf..2676c32 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -50,6 +50,8 @@ from pytato.array import ( sum, amax, amin, prod, real, imag, + dot, vdot, + ) from pytato.loopy import LoopyCall @@ -90,4 +92,6 @@ __all__ = ( "sum", "amax", "amin", "prod", "real", "imag", + "dot", "vdot", + ) diff --git a/pytato/array.py b/pytato/array.py index f39109b..7eda1b0 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -96,6 +96,8 @@ These functions generally follow the interface of the corresponding functions in .. autofunction:: amax .. autofunction:: prod .. autofunction:: einsum +.. autofunction:: dot +.. autofunction:: vdot .. currentmodule:: pytato.array @@ -2445,4 +2447,54 @@ def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: # }}} + +def dot(a: ArrayOrScalar, b: ArrayOrScalar) -> ArrayOrScalar: + """ + For 1-dimensional arrays *a* and *b* computes their inner product. See + :func:`numpy.dot` for behavior in the case when *a* and *b* aren't + single-dimensional arrays. + """ + import pytato as pt + + if isinstance(a, SCALAR_CLASSES) or isinstance(b, SCALAR_CLASSES): + # type-ignored because Number * bool is undefined + return a * b # type: ignore + + assert isinstance(a, Array) + assert isinstance(b, Array) + + if a.ndim == b.ndim == 1: + return pt.sum(a*b) + elif a.ndim == b.ndim == 2: + return a @ b + elif a.ndim == 0 or b.ndim == 0: + return a * b + elif b.ndim == 1: + return pt.sum(a * b, axis=(a.ndim - 1)) + else: + idx_stream = (chr(i) for i in range(ord("i"), ord("z"))) + idx_gen: Callable[[], str] = lambda: next(idx_stream) # noqa: E731 + a_indices = "".join(idx_gen() for _ in range(a.ndim)) + b_indices = "".join(idx_gen() for _ in range(b.ndim)) + # reduce over second-to-last axis of *b* and last axis of *a* + b_indices = b_indices[:-2] + a_indices[-1] + b_indices[-1] + result_indices = a_indices[:-1] + b_indices[:-2] + b_indices[-1] + return pt.einsum(f"{a_indices}, {b_indices} -> {result_indices}", a, b) + + +def vdot(a: Array, b: Array) -> ArrayOrScalar: + """ + Returns the dot-product of conjugate of *a* with *b*. If the input + arguments are multi-dimensional arrays, they are ravel-ed first and then + their *vdot* is computed. + """ + import pytato as pt + + if isinstance(a, Array) and a.ndim > 1: + a = a.reshape(-1) + if isinstance(b, Array) and b.ndim > 1: + b = b.reshape(-1) + + return pt.dot(pt.conj(a), b) + # vim: foldmethod=marker -- GitLab