From 1251f96ed4ae1daa33149870654a6d3bfb189180 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Wed, 22 Jun 2022 13:47:40 -0500
Subject: [PATCH] add test_c_instruction_in_callee

---
 test/test_callables.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/test/test_callables.py b/test/test_callables.py
index e092ae0da..8ac29f39a 100644
--- a/test/test_callables.py
+++ b/test/test_callables.py
@@ -1305,6 +1305,43 @@ def test_inlining_does_not_lose_preambles(ctx_factory, inline):
 # }}}
 
 
+@pytest.mark.parametrize("inline", [True, False])
+def test_c_instruction_in_callee(ctx_factory, inline):
+    from loopy.symbolic import parse
+
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    n = np.random.randint(3, 8)
+
+    knl = lp.make_function(
+        "{[i]: 0<=i<10}",
+        [lp.CInstruction(iname_exprs=("i", "i"),
+                         code="break;",
+                         predicates={parse("i >= n")},
+                         id="break",),
+         lp.Assignment("out_callee", "i", depends_on=frozenset(["break"]))
+         ],
+        [lp.ValueArg("n", dtype="int32"), ...],
+        name="circuit_breaker")
+
+    t_unit = lp.make_kernel(
+        "{ : }",
+        """
+        []: result[0] = circuit_breaker(N)
+        """,
+        [lp.ValueArg("N", dtype="int32"), ...],)
+
+    t_unit = lp.merge([t_unit, knl])
+
+    if inline:
+        t_unit = lp.inline_callable_kernel(t_unit, "circuit_breaker")
+
+    _, (out,) = t_unit(cq, N=n)
+
+    assert out.get() == (n-1)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab