diff --git a/test/test_transform.py b/test/test_transform.py
index ed00ebd1bbd49d91676d53192f0434ddfd97ed4d..ae0a577c1be3b20b68f1befe2f98d7a0e63c4211 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -553,6 +553,8 @@ def test_remove_work(ctx_factory):
                 ],
             assumptions="n>=1")
 
+    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32))
+
     knl = lp.split_iname(knl, "i", 16, outer_tag="g.1", inner_tag="l.1")
     knl = lp.split_iname(knl, "j", 16, outer_tag="g.0", inner_tag="l.0")
     knl = lp.add_prefetch(knl, "a", ["i_inner", "j_inner"],
@@ -562,7 +564,8 @@ def test_remove_work(ctx_factory):
     from loopy.transform.instruction import remove_work
     knl = remove_work(knl)
 
-    lp.auto_test_vs_ref(None, ctx, knl, print_ref_code=False)
+    lp.auto_test_vs_ref(knl, ctx, None, print_ref_code=False,
+            parameters=dict(n=512))
 
 
 if __name__ == "__main__":