From fbb75928b98f3537eee7c47a8074053d020583dc Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sat, 12 Jun 2021 16:13:35 -0500
Subject: [PATCH] PytatoArrayContext.einsum: thaw arguments before einsumming
 them

---
 arraycontext/impl/pytato.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py
index a7cea85..a8748ca 100644
--- a/arraycontext/impl/pytato.py
+++ b/arraycontext/impl/pytato.py
@@ -421,9 +421,23 @@ class PytatoArrayContext(ArrayContext):
         # Sorry, not capable.
         return array
 
-    def einsum(self, spec, *args, tagged=()):
+    def einsum(self, spec, *args, arg_names=None, tagged=()):
+        if arg_names is not None:
+            from warnings import warn
+            warn("'arg_names' don't bear any significance in PytatoArrayContext.",
+                 stacklevel=2)
+
         import pytato as pt
-        return pt.einsum(spec, *args)
+        import pyopencl.array as cla
+
+        def preprocess_arg(arg):
+            if isinstance(arg, cla.Array):
+                return self.thaw(arg)
+            else:
+                assert isinstance(arg, pt.Array)
+                return arg
+
+        return pt.einsum(spec, *(preprocess_arg(arg) for arg in args))
 
 
 # }}}
-- 
GitLab