From 54d0030e99c9de4c8be0a6471a9bdaca265a64eb Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 21 Mar 2025 14:00:02 -0500
Subject: [PATCH] Specialize freeze_thaw for pytato actx

---
 arraycontext/impl/pytato/__init__.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index f7f7be8..60aadba 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
 # }}}
 
 
+class _NotOnlyDataWrappers(Exception):  # noqa: N818
+    pass
+
+
 # {{{ _BasePytatoArrayContext
 
 class _BasePytatoArrayContext(ArrayContext, abc.ABC):
@@ -585,6 +589,24 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
             actx=self)
 
+    def freeze_thaw(self, array):
+        import pytato as pt
+
+        import arraycontext.impl.pyopencl.taggable_cl_array as tga
+
+        def _ft(ary):
+            if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)):
+                return ary
+            else:
+                raise _NotOnlyDataWrappers()
+
+        try:
+            return with_array_context(
+                self._rec_map_container(_ft, array),
+                actx=self)
+        except _NotOnlyDataWrappers:
+            return super().freeze_thaw(array)
+
     def tag(self, tags: ToTagSetConvertible, array):
         def _tag(ary):
             return ary.tagged(_preprocess_array_tags(tags))
-- 
GitLab