From efd2d687d00df4749c270dc835fdc839d1042b25 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Sat, 11 Jun 2022 12:26:18 +0300 Subject: [PATCH] port deprecated uses of freeze and thaw --- arraycontext/container/traversal.py | 2 +- test/test_arraycontext.py | 17 +++++++++-------- test/test_pytato_arraycontext.py | 16 ++++++++-------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index a53cd5d..84bd546 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -883,7 +883,7 @@ def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: return rec_map_array_container(_to_numpy_with_check, # do a freeze first, if 'actx' supports # container-wide freezes - thaw(freeze(ary, actx), actx)) + actx.thaw(actx.freeze(ary))) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 7b88698..0848987 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,8 +31,7 @@ from pytools.obj_array import make_obj_array from arraycontext import ( ArrayContext, dataclass_array_container, with_container_arithmetic, - serialize_container, deserialize_container, - freeze, thaw, with_array_context, + serialize_container, deserialize_container, with_array_context, FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, @@ -955,22 +954,22 @@ def test_container_freeze_thaw(actx_factory): assert get_container_context_recursively_opt(mat_of_dofs) is actx for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: - frozen_ary = freeze(ary) - thawed_ary = thaw(frozen_ary, actx) - frozen_ary = freeze(thawed_ary) + frozen_ary = actx.freeze(ary) + thawed_ary = actx.thaw(frozen_ary) + frozen_ary = actx.freeze(thawed_ary) assert get_container_context_recursively_opt(frozen_ary) is None assert get_container_context_recursively_opt(thawed_ary) is actx actx2 = actx.clone() - ary_dof_frozen = freeze(ary_dof) + ary_dof_frozen = actx.freeze(ary_dof) with pytest.raises(ValueError) as exc_info: ary_dof + ary_dof_frozen assert "frozen" in str(exc_info.value) - ary_dof_2 = thaw(freeze(ary_dof), actx2) + ary_dof_2 = actx2.thaw(actx.freeze(ary_dof)) with pytest.raises(ValueError): ary_dof + ary_dof_2 @@ -1434,7 +1433,9 @@ def test_actx_compile_on_pure_array_return(actx_factory): return 2 * x actx = actx_factory() - ones = actx.zeros(shape=(10, 4), dtype=np.float64) + 1 + ones = actx.thaw(actx.freeze( + actx.zeros(shape=(10, 4), dtype=np.float64) + 1 + )) np.testing.assert_allclose(actx.to_numpy(_twice(ones)), actx.to_numpy(actx.compile(_twice)(ones))) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index b71f795..1d40ae3 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -22,7 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext import (freeze, thaw, PytatoPyOpenCLArrayContext) +from arraycontext import PytatoPyOpenCLArrayContext from arraycontext import pytest_generate_tests_for_array_contexts from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory from pytools.tag import Tag @@ -83,13 +83,13 @@ def test_tags_preserved_after_freeze(actx_factory): rng = default_rng() actx = actx_factory() - foo = thaw(freeze(actx - .from_numpy(rng.random((10, 4))) - .tagged(FooTag()) - .with_tagged_axis(0, BarTag()) - .with_tagged_axis(1, BazTag()), - actx), - actx) + foo = actx.thaw(actx.freeze( + actx.from_numpy(rng.random((10, 4))) + .tagged(FooTag()) + .with_tagged_axis(0, BarTag()) + .with_tagged_axis(1, BazTag()) + )) + assert foo.tags_of_type(FooTag) assert foo.axes[0].tags_of_type(BarTag) assert foo.axes[1].tags_of_type(BazTag) -- GitLab