diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index a895eb4d525145bca47663838bcd08cab166e2c4..216bd1e1e59166dacc8773784b3264142e4f0aa4 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -74,10 +74,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def __init__(self, queue, allocator=None): import pytato as pt + import pyopencl.array as cla super().__init__() self.queue = queue self.allocator = allocator - self.array_types = (pt.Array, ) + self.array_types = (pt.Array, cla.Array) self._freeze_prg_cache = {} self._dag_transform_cache = {} @@ -114,6 +115,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return cl_array.get(queue=self.queue) def call_loopy(self, program, **kwargs): + import pytato as pt from pytato.scalar_expr import SCALAR_CLASSES from pytato.loopy import call_loopy from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray @@ -125,7 +127,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): processed_kwargs = {} for kw, arg in sorted(kwargs.items()): - if isinstance(arg, self.array_types + SCALAR_CLASSES): + if isinstance(arg, (pt.Array,) + SCALAR_CLASSES): pass elif isinstance(arg, TaggableCLArray): arg = self.thaw(arg)