diff --git a/test/test_transform.py b/test/test_transform.py index ed00ebd1bbd49d91676d53192f0434ddfd97ed4d..ae0a577c1be3b20b68f1befe2f98d7a0e63c4211 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -553,6 +553,8 @@ def test_remove_work(ctx_factory): ], assumptions="n>=1") + knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32)) + knl = lp.split_iname(knl, "i", 16, outer_tag="g.1", inner_tag="l.1") knl = lp.split_iname(knl, "j", 16, outer_tag="g.0", inner_tag="l.0") knl = lp.add_prefetch(knl, "a", ["i_inner", "j_inner"], @@ -562,7 +564,8 @@ def test_remove_work(ctx_factory): from loopy.transform.instruction import remove_work knl = remove_work(knl) - lp.auto_test_vs_ref(None, ctx, knl, print_ref_code=False) + lp.auto_test_vs_ref(knl, ctx, None, print_ref_code=False, + parameters=dict(n=512)) if __name__ == "__main__":