diff --git a/test/test_target.py b/test/test_target.py index f29fbeecc1b7dddd7abd2ce6a970580df3e96b93..b4cb509de72ad1d99c067c77699d6ca04a7c6063 100644 --- a/test/test_target.py +++ b/test/test_target.py @@ -842,6 +842,30 @@ def test_to_complex_casts(ctx_factory): cl.Program(ctx, code).build() +def test_cl_vectorize_ternary(ctx_factory): + knl = lp.make_kernel( + "{ [i]: 0<=i0") + + rng = np.random.default_rng(seed=12) + a = rng.normal(size=(16, 4)) + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + _evt, (result,) = knl(queue, a=a, n=a.size) + + result_ref = np.where(a < 0, a*3, np.sin(a)) + assert np.allclose(result, result_ref) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])