From 1f5d92a820231ca032e0f01eb0640ea4bf76e814 Mon Sep 17 00:00:00 2001 From: Thomas Gibson <gibsonthomas1120@hotmail.com> Date: Fri, 21 May 2021 22:25:25 -0500 Subject: [PATCH] Add array manipulation unit tests --- test/test_arraycontext.py | 58 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ac31d2a..752f69b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -178,6 +178,43 @@ def test_array_context_np_workalike(actx_factory): # }}} +# {{{ Array manipulations + +def test_actx_stack(actx_factory): + actx = actx_factory() + + ndofs = 5000 + args = [np.random.randn(ndofs) for i in range(10)] + ref_result = np.stack(args) + + # {{{ test cl.Arrays + + actx_args = [actx.from_numpy(arg) for arg in args] + actx_result = actx.to_numpy(actx.np.stack(actx_args)) + + assert np.allclose(actx_result, ref_result) + + # }}} + + # {{{ test DOFArrays + + actx_args = [DOFArray(actx, (arg,)) for arg in actx_args] + actx_result = actx.to_numpy(actx.np.stack(actx_args)[0]) + + assert np.allclose(actx_result, ref_result) + + # }}} + + # {{{ test object arrays + + obj_array_args = [make_obj_array([arg]) for arg in actx_args] + obj_array_result = actx.to_numpy(actx.np.stack(obj_array_args)[0][0]) + + assert np.allclose(obj_array_result, ref_result) + + # }}} + + def test_actx_concatenate(actx_factory): actx = actx_factory() @@ -195,6 +232,27 @@ def test_actx_concatenate(actx_factory): # }}} +def test_actx_reshape(actx_factory): + actx = actx_factory() + + numpy_ary = np.random.randn(2, 3) + actx_ary = actx.from_numpy(numpy_ary) + + assert np.allclose(actx.to_numpy(actx.np.reshape(actx_ary, (3, 2))), + np.reshape(numpy_ary, (3, 2))) + + assert np.allclose(actx.to_numpy(actx.np.reshape(actx_ary, (3, -1))), + np.reshape(numpy_ary, (3, -1))) + + assert np.allclose(actx.to_numpy(actx.np.reshape(actx_ary, (6,))), + np.reshape(numpy_ary, (6,))) + + assert np.allclose(actx.to_numpy(actx.np.reshape(actx_ary, -1)), + np.reshape(numpy_ary, -1)) + +# }}} + + def test_dof_array_arithmetic_same_as_numpy(actx_factory): actx = actx_factory() -- GitLab