From efd2d687d00df4749c270dc835fdc839d1042b25 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Sat, 11 Jun 2022 12:26:18 +0300
Subject: [PATCH] port deprecated uses of freeze and thaw

---
 arraycontext/container/traversal.py |  2 +-
 test/test_arraycontext.py           | 17 +++++++++--------
 test/test_pytato_arraycontext.py    | 16 ++++++++--------
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index a53cd5d..84bd546 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -883,7 +883,7 @@ def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
     return rec_map_array_container(_to_numpy_with_check,
                                    # do a freeze first, if 'actx' supports
                                    # container-wide freezes
-                                   thaw(freeze(ary, actx), actx))
+                                   actx.thaw(actx.freeze(ary)))
 
 # }}}
 
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 7b88698..0848987 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -31,8 +31,7 @@ from pytools.obj_array import make_obj_array
 from arraycontext import (
         ArrayContext,
         dataclass_array_container, with_container_arithmetic,
-        serialize_container, deserialize_container,
-        freeze, thaw, with_array_context,
+        serialize_container, deserialize_container, with_array_context,
         FirstAxisIsElementsTag,
         PyOpenCLArrayContext,
         PytatoPyOpenCLArrayContext,
@@ -955,22 +954,22 @@ def test_container_freeze_thaw(actx_factory):
     assert get_container_context_recursively_opt(mat_of_dofs) is actx
 
     for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
-        frozen_ary = freeze(ary)
-        thawed_ary = thaw(frozen_ary, actx)
-        frozen_ary = freeze(thawed_ary)
+        frozen_ary = actx.freeze(ary)
+        thawed_ary = actx.thaw(frozen_ary)
+        frozen_ary = actx.freeze(thawed_ary)
 
         assert get_container_context_recursively_opt(frozen_ary) is None
         assert get_container_context_recursively_opt(thawed_ary) is actx
 
     actx2 = actx.clone()
 
-    ary_dof_frozen = freeze(ary_dof)
+    ary_dof_frozen = actx.freeze(ary_dof)
     with pytest.raises(ValueError) as exc_info:
         ary_dof + ary_dof_frozen
 
     assert "frozen" in str(exc_info.value)
 
-    ary_dof_2 = thaw(freeze(ary_dof), actx2)
+    ary_dof_2 = actx2.thaw(actx.freeze(ary_dof))
 
     with pytest.raises(ValueError):
         ary_dof + ary_dof_2
@@ -1434,7 +1433,9 @@ def test_actx_compile_on_pure_array_return(actx_factory):
         return 2 * x
 
     actx = actx_factory()
-    ones = actx.zeros(shape=(10, 4), dtype=np.float64) + 1
+    ones = actx.thaw(actx.freeze(
+        actx.zeros(shape=(10, 4), dtype=np.float64) + 1
+        ))
     np.testing.assert_allclose(actx.to_numpy(_twice(ones)),
                                actx.to_numpy(actx.compile(_twice)(ones)))
 
diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py
index b71f795..1d40ae3 100644
--- a/test/test_pytato_arraycontext.py
+++ b/test/test_pytato_arraycontext.py
@@ -22,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext import (freeze, thaw, PytatoPyOpenCLArrayContext)
+from arraycontext import PytatoPyOpenCLArrayContext
 from arraycontext import pytest_generate_tests_for_array_contexts
 from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory
 from pytools.tag import Tag
@@ -83,13 +83,13 @@ def test_tags_preserved_after_freeze(actx_factory):
     rng = default_rng()
 
     actx = actx_factory()
-    foo = thaw(freeze(actx
-                      .from_numpy(rng.random((10, 4)))
-                      .tagged(FooTag())
-                      .with_tagged_axis(0, BarTag())
-                      .with_tagged_axis(1, BazTag()),
-                      actx),
-               actx)
+    foo = actx.thaw(actx.freeze(
+        actx.from_numpy(rng.random((10, 4)))
+        .tagged(FooTag())
+        .with_tagged_axis(0, BarTag())
+        .with_tagged_axis(1, BazTag())
+        ))
+
     assert foo.tags_of_type(FooTag)
     assert foo.axes[0].tags_of_type(BarTag)
     assert foo.axes[1].tags_of_type(BazTag)
-- 
GitLab