diff --git a/test/test_transform.py b/test/test_transform.py index 770e43617f05a499293de9435122baa8484ebfb2..ccbbd3da0c61f491ff43a32fe2a7a19f8ec55a9d 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -107,7 +107,6 @@ def test_to_batched(ctx_factory): def test_to_batched_temp(ctx_factory): ctx = ctx_factory() - queue = cl.CommandQueue(ctx) knl = lp.make_kernel( ''' { [i,j]: 0<=i,j<n } ''', @@ -115,16 +114,28 @@ def test_to_batched_temp(ctx_factory): out[i] = sum(j, cnst*a[i,j]*x[j])''', [lp.TemporaryVariable( "cnst", - dtype=np.float64, + dtype=np.float32, shape=(), scope=lp.temp_var_scope.PRIVATE), '...']) + knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32, + x=np.float32, + a=np.float32)) + ref_knl = lp.make_kernel( + ''' { [i,j]: 0<=i,j<n } ''', + '''out[i] = sum(j, 2.0*a[i,j]*x[j])''') + ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32, + x=np.float32, + a=np.float32)) bknl = lp.to_batched(knl, "nbatches", "out,x") + bref_knl = lp.to_batched(ref_knl, "nbatches", "out,x") a = np.random.randn(5, 5) x = np.random.randn(7, 5) - bknl(queue, a=a, x=x) + lp.auto_test_vs_ref( + bref_knl, ctx, bknl, + parameters=dict(a=a, x=x, n=5, nbatches=7)) def test_add_barrier(ctx_factory):