From 8b0eff23b88fed0254eb59a7ed1c036861ff2533 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Tue, 9 May 2023 10:08:26 -0500
Subject: [PATCH] rewrite matmul using einsum

---
 pytato/array.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytato/array.py b/pytato/array.py
index 09dff52..964ff35 100644
--- a/pytato/array.py
+++ b/pytato/array.py
@@ -1823,7 +1823,9 @@ def matmul(x1: Array, x2: Array) -> Array:
     elif x1.ndim == 1:
         return cast(Array, pt.dot(x1, x2))
     elif x2.ndim == 1:
-        return pt.sum(x1 * x2, axis=(x1.ndim - 1))
+        x1_indices = index_names[:x1.ndim]
+        return pt.einsum(f"{x1_indices}, {x1_indices[-1]} -> {x1_indices[:-1]}",
+                         x1, x2)
 
     stack_indices = index_names[:max(x1.ndim-2, x2.ndim-2)]
     x1_indices = stack_indices[len(stack_indices) - x1.ndim+2:] + "ij"
-- 
GitLab