From 54d0030e99c9de4c8be0a6471a9bdaca265a64eb Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Fri, 21 Mar 2025 14:00:02 -0500 Subject: [PATCH] Specialize freeze_thaw for pytato actx --- arraycontext/impl/pytato/__init__.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7f7be8..60aadba 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: # }}} +class _NotOnlyDataWrappers(Exception): # noqa: N818 + pass + + # {{{ _BasePytatoArrayContext class _BasePytatoArrayContext(ArrayContext, abc.ABC): @@ -585,6 +589,24 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), actx=self) + def freeze_thaw(self, array): + import pytato as pt + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + + def _ft(ary): + if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)): + return ary + else: + raise _NotOnlyDataWrappers() + + try: + return with_array_context( + self._rec_map_container(_ft, array), + actx=self) + except _NotOnlyDataWrappers: + return super().freeze_thaw(array) + def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): return ary.tagged(_preprocess_array_tags(tags)) -- GitLab