Skip to content
Snippets Groups Projects
Commit 286011de authored by Kaushik Kulkarni's avatar Kaushik Kulkarni
Browse files

adds test_actx_compile

parent 666be4a4
No related branches found
No related tags found
No related merge requests found
......@@ -754,6 +754,40 @@ def test_norm_ord_none(actx_factory, ndim):
np.testing.assert_allclose(norm_a, norm_a_ref)
# {{{ test_actx_compile helpers
@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
@dataclass_array_container
@dataclass(frozen=True)
class Velocity2D:
u: np.ndarray
v: np.ndarray
def scale_and_to_speed(alpha, vel):
actx = vel.array_context
scaled_vel = alpha * vel
return actx.np.sqrt(scaled_vel.u**2 + scaled_vel.v**2)
# }}}
def test_actx_compile(actx_factory):
actx = actx_factory()
compiled_rhs = actx.compile(scale_and_to_speed)
v_x = np.random.rand(10)
v_y = np.random.rand(10)
vel = actx.from_numpy(Velocity2D(v_x, v_y))
scaled_speed = compiled_rhs(3.14, vel)
np.testing.assert_allclose(actx.to_numpy(scaled_speed),
3.14 * np.sqrt(v_x ** 2 + v_y ** 2))
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment