diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index dbc725f70a5e27f80e9903706bfe19b75e4bfbcc..89c4e885b34546d81c15f6d7de426241de27ebce 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -38,6 +38,8 @@ import loopy as lp from pytools.tag import Tag from arraycontext.context import ArrayContext +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) class NumpyArrayContext(ArrayContext): @@ -91,10 +93,16 @@ class NumpyArrayContext(ArrayContext): return result def freeze(self, array): - return array + def _freeze(ary): + return ary + + return with_array_context(rec_map_array_container(_freeze, array), actx=None) def thaw(self, array): - return array + def _thaw(ary): + return ary + + return with_array_context(rec_map_array_container(_thaw, array), actx=self) # }}}