From c4a49e4c4736d8e56757c9df98f6d38813681b0d Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 9 Jan 2018 19:20:56 -0600 Subject: [PATCH] Added test for the private temps in `to_batched` --- test/test_transform.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_transform.py b/test/test_transform.py index e50605b46..770e43617 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -105,6 +105,28 @@ def test_to_batched(ctx_factory): bknl(queue, a=a, x=x) +def test_to_batched_temp(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + ''' { [i,j]: 0<=i,j<n } ''', + ''' cnst = 2.0 + out[i] = sum(j, cnst*a[i,j]*x[j])''', + [lp.TemporaryVariable( + "cnst", + dtype=np.float64, + shape=(), + scope=lp.temp_var_scope.PRIVATE), '...']) + + bknl = lp.to_batched(knl, "nbatches", "out,x") + + a = np.random.randn(5, 5) + x = np.random.randn(7, 5) + + bknl(queue, a=a, x=x) + + def test_add_barrier(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) -- GitLab