diff --git a/pytato/__init__.py b/pytato/__init__.py index 114eedff0e914771aab751235b81e67852d2a3b7..2676c323bfa4d4d260718ddef0ca7d9553332e3f 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -50,6 +50,8 @@ from pytato.array import ( sum, amax, amin, prod, real, imag, + dot, vdot, + ) from pytato.loopy import LoopyCall @@ -90,4 +92,6 @@ __all__ = ( "sum", "amax", "amin", "prod", "real", "imag", + "dot", "vdot", + ) diff --git a/pytato/array.py b/pytato/array.py index f39109bf16ce9888949cfc7346e442ece2dca72e..7eda1b0869b57ce6640fc5ff78009b4d039f46b5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -96,6 +96,8 @@ These functions generally follow the interface of the corresponding functions in .. autofunction:: amax .. autofunction:: prod .. autofunction:: einsum +.. autofunction:: dot +.. autofunction:: vdot .. currentmodule:: pytato.array @@ -2445,4 +2447,54 @@ def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: # }}} + +def dot(a: ArrayOrScalar, b: ArrayOrScalar) -> ArrayOrScalar: + """ + For 1-dimensional arrays *a* and *b* computes their inner product. See + :func:`numpy.dot` for behavior in the case when *a* and *b* aren't + single-dimensional arrays. + """ + import pytato as pt + + if isinstance(a, SCALAR_CLASSES) or isinstance(b, SCALAR_CLASSES): + # type-ignored because Number * bool is undefined + return a * b # type: ignore + + assert isinstance(a, Array) + assert isinstance(b, Array) + + if a.ndim == b.ndim == 1: + return pt.sum(a*b) + elif a.ndim == b.ndim == 2: + return a @ b + elif a.ndim == 0 or b.ndim == 0: + return a * b + elif b.ndim == 1: + return pt.sum(a * b, axis=(a.ndim - 1)) + else: + idx_stream = (chr(i) for i in range(ord("i"), ord("z"))) + idx_gen: Callable[[], str] = lambda: next(idx_stream) # noqa: E731 + a_indices = "".join(idx_gen() for _ in range(a.ndim)) + b_indices = "".join(idx_gen() for _ in range(b.ndim)) + # reduce over second-to-last axis of *b* and last axis of *a* + b_indices = b_indices[:-2] + a_indices[-1] + b_indices[-1] + result_indices = a_indices[:-1] + b_indices[:-2] + b_indices[-1] + return pt.einsum(f"{a_indices}, {b_indices} -> {result_indices}", a, b) + + +def vdot(a: Array, b: Array) -> ArrayOrScalar: + """ + Returns the dot-product of conjugate of *a* with *b*. If the input + arguments are multi-dimensional arrays, they are ravel-ed first and then + their *vdot* is computed. + """ + import pytato as pt + + if isinstance(a, Array) and a.ndim > 1: + a = a.reshape(-1) + if isinstance(b, Array) and b.ndim > 1: + b = b.reshape(-1) + + return pt.dot(pt.conj(a), b) + # vim: foldmethod=marker