diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index a895eb4d525145bca47663838bcd08cab166e2c4..216bd1e1e59166dacc8773784b3264142e4f0aa4 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -74,10 +74,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
 
     def __init__(self, queue, allocator=None):
         import pytato as pt
+        import pyopencl.array as cla
         super().__init__()
         self.queue = queue
         self.allocator = allocator
-        self.array_types = (pt.Array, )
+        self.array_types = (pt.Array, cla.Array)
         self._freeze_prg_cache = {}
         self._dag_transform_cache = {}
 
@@ -114,6 +115,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         return cl_array.get(queue=self.queue)
 
     def call_loopy(self, program, **kwargs):
+        import pytato as pt
         from pytato.scalar_expr import SCALAR_CLASSES
         from pytato.loopy import call_loopy
         from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
@@ -125,7 +127,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         processed_kwargs = {}
 
         for kw, arg in sorted(kwargs.items()):
-            if isinstance(arg, self.array_types + SCALAR_CLASSES):
+            if isinstance(arg, (pt.Array,) + SCALAR_CLASSES):
                 pass
             elif isinstance(arg, TaggableCLArray):
                 arg = self.thaw(arg)