diff --git a/arraycontext/impl/pytato.py b/arraycontext/impl/pytato.py index a7cea85dd0c18454b24d4df7a310c9ad9a7f8c84..a8748cab47a0b2dbb1bd3751812678c472bb5a9d 100644 --- a/arraycontext/impl/pytato.py +++ b/arraycontext/impl/pytato.py @@ -421,9 +421,23 @@ class PytatoArrayContext(ArrayContext): # Sorry, not capable. return array - def einsum(self, spec, *args, tagged=()): + def einsum(self, spec, *args, arg_names=None, tagged=()): + if arg_names is not None: + from warnings import warn + warn("'arg_names' don't bear any significance in PytatoArrayContext.", + stacklevel=2) + import pytato as pt - return pt.einsum(spec, *args) + import pyopencl.array as cla + + def preprocess_arg(arg): + if isinstance(arg, cla.Array): + return self.thaw(arg) + else: + assert isinstance(arg, pt.Array) + return arg + + return pt.einsum(spec, *(preprocess_arg(arg) for arg in args)) # }}}