diff --git a/test/test_loopy.py b/test/test_loopy.py index de64c820ed20f22701be4df70e38a0e4d6401ae6..7dc080e2d1d0c0a81151fe558f63661768c9ef8b 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2291,6 +2291,28 @@ def test_is_expression_equal(): assert is_expression_equal((x+y)**2, x**2 + 2*x*y + y**2) +def test_collect_common_factors(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i,j,k]: 0<=i,j out_tmp = 0 {id=out_init,inames=i} + out_tmp = out_tmp + alpha[i]*a[i,j]*b1[j] {id=out_up1,dep=out_init} + out_tmp = out_tmp + alpha[i]*a[j,i]*b2[j] {id=out_up2,dep=out_init} + out[i] = out_tmp {dep=out_up1:out_up2} + """) + knl = lp.add_and_infer_dtypes(knl, + dict(a=np.float32, alpha=np.float32, b1=np.float32, b2=np.float32)) + + ref_knl = knl + + knl = lp.split_iname(knl, "i", 256, outer_tag="g.0", inner_tag="l.0") + knl = lp.collect_common_factors_on_increment(knl, "out_tmp") + + lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=13)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])