diff --git a/test/test_callables.py b/test/test_callables.py index e092ae0da5b2aeadff6f72f6d8208e3ab08a81ee..8ac29f39ad5f99a42094c9ff7cf44ba373e3d66e 100644 --- a/test/test_callables.py +++ b/test/test_callables.py @@ -1305,6 +1305,43 @@ def test_inlining_does_not_lose_preambles(ctx_factory, inline): # }}} +@pytest.mark.parametrize("inline", [True, False]) +def test_c_instruction_in_callee(ctx_factory, inline): + from loopy.symbolic import parse + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + n = np.random.randint(3, 8) + + knl = lp.make_function( + "{[i]: 0<=i<10}", + [lp.CInstruction(iname_exprs=("i", "i"), + code="break;", + predicates={parse("i >= n")}, + id="break",), + lp.Assignment("out_callee", "i", depends_on=frozenset(["break"])) + ], + [lp.ValueArg("n", dtype="int32"), ...], + name="circuit_breaker") + + t_unit = lp.make_kernel( + "{ : }", + """ + []: result[0] = circuit_breaker(N) + """, + [lp.ValueArg("N", dtype="int32"), ...],) + + t_unit = lp.merge([t_unit, knl]) + + if inline: + t_unit = lp.inline_callable_kernel(t_unit, "circuit_breaker") + + _, (out,) = t_unit(cq, N=n) + + assert out.get() == (n-1) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])