diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7f7be8db312a03426c16bbbcf70e3f7ce9c94af..60aadba99beaa2bb849f0097cb5648fe4c858ec8 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: # }}} +class _NotOnlyDataWrappers(Exception): # noqa: N818 + pass + + # {{{ _BasePytatoArrayContext class _BasePytatoArrayContext(ArrayContext, abc.ABC): @@ -585,6 +589,24 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), actx=self) + def freeze_thaw(self, array): + import pytato as pt + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + + def _ft(ary): + if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)): + return ary + else: + raise _NotOnlyDataWrappers() + + try: + return with_array_context( + self._rec_map_container(_ft, array), + actx=self) + except _NotOnlyDataWrappers: + return super().freeze_thaw(array) + def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): return ary.tagged(_preprocess_array_tags(tags))