From 286011de1381cd043bf777d86b4423c41834b899 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 15 Jun 2021 15:35:50 -0500 Subject: [PATCH] adds test_actx_compile --- test/test_arraycontext.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 19b427e..fcf17c0 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -754,6 +754,40 @@ def test_norm_ord_none(actx_factory, ndim): np.testing.assert_allclose(norm_a, norm_a_ref) +# {{{ test_actx_compile helpers + +@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@dataclass_array_container +@dataclass(frozen=True) +class Velocity2D: + u: np.ndarray + v: np.ndarray + + +def scale_and_to_speed(alpha, vel): + actx = vel.array_context + scaled_vel = alpha * vel + return actx.np.sqrt(scaled_vel.u**2 + scaled_vel.v**2) + +# }}} + + +def test_actx_compile(actx_factory): + actx = actx_factory() + + compiled_rhs = actx.compile(scale_and_to_speed) + + v_x = np.random.rand(10) + v_y = np.random.rand(10) + + vel = actx.from_numpy(Velocity2D(v_x, v_y)) + + scaled_speed = compiled_rhs(3.14, vel) + + np.testing.assert_allclose(actx.to_numpy(scaled_speed), + 3.14 * np.sqrt(v_x ** 2 + v_y ** 2)) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab