From 8b0eff23b88fed0254eb59a7ed1c036861ff2533 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 9 May 2023 10:08:26 -0500 Subject: [PATCH] rewrite matmul using einsum --- pytato/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 09dff52..964ff35 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1823,7 +1823,9 @@ def matmul(x1: Array, x2: Array) -> Array: elif x1.ndim == 1: return cast(Array, pt.dot(x1, x2)) elif x2.ndim == 1: - return pt.sum(x1 * x2, axis=(x1.ndim - 1)) + x1_indices = index_names[:x1.ndim] + return pt.einsum(f"{x1_indices}, {x1_indices[-1]} -> {x1_indices[:-1]}", + x1, x2) stack_indices = index_names[:max(x1.ndim-2, x2.ndim-2)] x1_indices = stack_indices[len(stack_indices) - x1.ndim+2:] + "ij" -- GitLab