From cc274f61f25386ef149646197172b6383d9ffb00 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Sun, 1 May 2022 17:05:58 -0500 Subject: [PATCH] adds cla.Array to PytatoPyOpenCLArrayContext.array_types --- arraycontext/impl/pytato/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index a895eb4..216bd1e 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -74,10 +74,11 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def __init__(self, queue, allocator=None): import pytato as pt + import pyopencl.array as cla super().__init__() self.queue = queue self.allocator = allocator - self.array_types = (pt.Array, ) + self.array_types = (pt.Array, cla.Array) self._freeze_prg_cache = {} self._dag_transform_cache = {} @@ -114,6 +115,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): return cl_array.get(queue=self.queue) def call_loopy(self, program, **kwargs): + import pytato as pt from pytato.scalar_expr import SCALAR_CLASSES from pytato.loopy import call_loopy from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray @@ -125,7 +127,7 @@ class PytatoPyOpenCLArrayContext(ArrayContext): processed_kwargs = {} for kw, arg in sorted(kwargs.items()): - if isinstance(arg, self.array_types + SCALAR_CLASSES): + if isinstance(arg, (pt.Array,) + SCALAR_CLASSES): pass elif isinstance(arg, TaggableCLArray): arg = self.thaw(arg) -- GitLab