diff --git a/test/test_transform.py b/test/test_transform.py
index 770e43617f05a499293de9435122baa8484ebfb2..ccbbd3da0c61f491ff43a32fe2a7a19f8ec55a9d 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -107,7 +107,6 @@ def test_to_batched(ctx_factory):
 
 def test_to_batched_temp(ctx_factory):
     ctx = ctx_factory()
-    queue = cl.CommandQueue(ctx)
 
     knl = lp.make_kernel(
          ''' { [i,j]: 0<=i,j<n } ''',
@@ -115,16 +114,28 @@ def test_to_batched_temp(ctx_factory):
          out[i] = sum(j, cnst*a[i,j]*x[j])''',
          [lp.TemporaryVariable(
              "cnst",
-             dtype=np.float64,
+             dtype=np.float32,
              shape=(),
              scope=lp.temp_var_scope.PRIVATE), '...'])
+    knl = lp.add_and_infer_dtypes(knl, dict(out=np.float32,
+                                            x=np.float32,
+                                            a=np.float32))
+    ref_knl = lp.make_kernel(
+         ''' { [i,j]: 0<=i,j<n } ''',
+         '''out[i] = sum(j, 2.0*a[i,j]*x[j])''')
+    ref_knl = lp.add_and_infer_dtypes(ref_knl, dict(out=np.float32,
+                                                    x=np.float32,
+                                                    a=np.float32))
 
     bknl = lp.to_batched(knl, "nbatches", "out,x")
+    bref_knl = lp.to_batched(ref_knl, "nbatches", "out,x")
 
     a = np.random.randn(5, 5)
     x = np.random.randn(7, 5)
 
-    bknl(queue, a=a, x=x)
+    lp.auto_test_vs_ref(
+            bref_knl, ctx, bknl,
+            parameters=dict(a=a, x=x, n=5, nbatches=7))
 
 
 def test_add_barrier(ctx_factory):