diff --git a/pytato/array.py b/pytato/array.py index 09dff523d099d23491db76fa8028324c2e6cc0c7..964ff351f29a07cbe16f93385725d913a67f9275 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1823,7 +1823,9 @@ def matmul(x1: Array, x2: Array) -> Array: elif x1.ndim == 1: return cast(Array, pt.dot(x1, x2)) elif x2.ndim == 1: - return pt.sum(x1 * x2, axis=(x1.ndim - 1)) + x1_indices = index_names[:x1.ndim] + return pt.einsum(f"{x1_indices}, {x1_indices[-1]} -> {x1_indices[:-1]}", + x1, x2) stack_indices = index_names[:max(x1.ndim-2, x2.ndim-2)] x1_indices = stack_indices[len(stack_indices) - x1.ndim+2:] + "ij"