From 286011de1381cd043bf777d86b4423c41834b899 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Tue, 15 Jun 2021 15:35:50 -0500
Subject: [PATCH] adds test_actx_compile

---
 test/test_arraycontext.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 19b427e..fcf17c0 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -754,6 +754,40 @@ def test_norm_ord_none(actx_factory, ndim):
     np.testing.assert_allclose(norm_a, norm_a_ref)
 
 
+# {{{ test_actx_compile helpers
+
+@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
+@dataclass_array_container
+@dataclass(frozen=True)
+class Velocity2D:
+    u: np.ndarray
+    v: np.ndarray
+
+
+def scale_and_to_speed(alpha, vel):
+    actx = vel.array_context
+    scaled_vel = alpha * vel
+    return actx.np.sqrt(scaled_vel.u**2 + scaled_vel.v**2)
+
+# }}}
+
+
+def test_actx_compile(actx_factory):
+    actx = actx_factory()
+
+    compiled_rhs = actx.compile(scale_and_to_speed)
+
+    v_x = np.random.rand(10)
+    v_y = np.random.rand(10)
+
+    vel = actx.from_numpy(Velocity2D(v_x, v_y))
+
+    scaled_speed = compiled_rhs(3.14, vel)
+
+    np.testing.assert_allclose(actx.to_numpy(scaled_speed),
+                               3.14 * np.sqrt(v_x ** 2 + v_y ** 2))
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab