From fbb75928b98f3537eee7c47a8074053d020583dc Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sat, 12 Jun 2021 16:13:35 -0500 Subject: [PATCH] PytatoArrayContext.einsum: thaw arguments before einsumming them --- arraycontext/impl/pytato.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index a7cea85..a8748ca 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -421,9 +421,23 @@ class PytatoArrayContext(ArrayContext): # Sorry, not capable. return array - def einsum(self, spec, *args, tagged=()): + def einsum(self, spec, *args, arg_names=None, tagged=()): + if arg_names is not None: + from warnings import warn + warn("'arg_names' don't bear any significance in PytatoArrayContext.", + stacklevel=2) + import pytato as pt - return pt.einsum(spec, *args) + import pyopencl.array as cla + + def preprocess_arg(arg): + if isinstance(arg, cla.Array): + return self.thaw(arg) + else: + assert isinstance(arg, pt.Array) + return arg + + return pt.einsum(spec, *(preprocess_arg(arg) for arg in args)) # }}} -- GitLab