diff --git a/test/test_transform.py b/test/test_transform.py index e50605b46672f8e9c1817431f1577742b1f6fb4c..770e43617f05a499293de9435122baa8484ebfb2 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -105,6 +105,28 @@ def test_to_batched(ctx_factory): bknl(queue, a=a, x=x) +def test_to_batched_temp(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + ''' { [i,j]: 0<=i,j<n } ''', + ''' cnst = 2.0 + out[i] = sum(j, cnst*a[i,j]*x[j])''', + [lp.TemporaryVariable( + "cnst", + dtype=np.float64, + shape=(), + scope=lp.temp_var_scope.PRIVATE), '...']) + + bknl = lp.to_batched(knl, "nbatches", "out,x") + + a = np.random.randn(5, 5) + x = np.random.randn(7, 5) + + bknl(queue, a=a, x=x) + + def test_add_barrier(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx)