Skip to content
Snippets Groups Projects
Commit 54d0030e authored by Andreas Klöckner's avatar Andreas Klöckner Committed by Andreas Klöckner
Browse files

Specialize freeze_thaw for pytato actx

parent 83873066
No related branches found
No related tags found
No related merge requests found
...@@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: ...@@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
# }}} # }}}
class _NotOnlyDataWrappers(Exception): # noqa: N818
pass
# {{{ _BasePytatoArrayContext # {{{ _BasePytatoArrayContext
class _BasePytatoArrayContext(ArrayContext, abc.ABC): class _BasePytatoArrayContext(ArrayContext, abc.ABC):
...@@ -585,6 +589,24 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): ...@@ -585,6 +589,24 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)),
actx=self) actx=self)
def freeze_thaw(self, array):
import pytato as pt
import arraycontext.impl.pyopencl.taggable_cl_array as tga
def _ft(ary):
if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)):
return ary
else:
raise _NotOnlyDataWrappers()
try:
return with_array_context(
self._rec_map_container(_ft, array),
actx=self)
except _NotOnlyDataWrappers:
return super().freeze_thaw(array)
def tag(self, tags: ToTagSetConvertible, array): def tag(self, tags: ToTagSetConvertible, array):
def _tag(ary): def _tag(ary):
return ary.tagged(_preprocess_array_tags(tags)) return ary.tagged(_preprocess_array_tags(tags))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment