From 1251f96ed4ae1daa33149870654a6d3bfb189180 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Wed, 22 Jun 2022 13:47:40 -0500 Subject: [PATCH] add test_c_instruction_in_callee --- test/test_callables.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/test_callables.py b/test/test_callables.py index e092ae0da..8ac29f39a 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -1305,6 +1305,43 @@ def test_inlining_does_not_lose_preambles(ctx_factory, inline): # }}} +@pytest.mark.parametrize("inline", [True, False]) +def test_c_instruction_in_callee(ctx_factory, inline): + from loopy.symbolic import parse + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + n = np.random.randint(3, 8) + + knl = lp.make_function( + "{[i]: 0<=i<10}", + [lp.CInstruction(iname_exprs=("i", "i"), + code="break;", + predicates={parse("i >= n")}, + id="break",), + lp.Assignment("out_callee", "i", depends_on=frozenset(["break"])) + ], + [lp.ValueArg("n", dtype="int32"), ...], + name="circuit_breaker") + + t_unit = lp.make_kernel( + "{ : }", + """ + []: result[0] = circuit_breaker(N) + """, + [lp.ValueArg("N", dtype="int32"), ...],) + + t_unit = lp.merge([t_unit, knl]) + + if inline: + t_unit = lp.inline_callable_kernel(t_unit, "circuit_breaker") + + _, (out,) = t_unit(cq, N=n) + + assert out.get() == (n-1) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab