From 3864505d68872387a54cdc577e46d2e75292205b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 1 Dec 2015 19:23:54 -0600 Subject: [PATCH] Add test for distributive law transform --- test/test_loopy.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_loopy.py b/test/test_loopy.py index de64c820e..7dc080e2d 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2291,6 +2291,28 @@ def test_is_expression_equal(): assert is_expression_equal((x+y)**2, x**2 + 2*x*y + y**2) +def test_collect_common_factors(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel( + "{[i,j,k]: 0<=i,j out_tmp = 0 {id=out_init,inames=i} + out_tmp = out_tmp + alpha[i]*a[i,j]*b1[j] {id=out_up1,dep=out_init} + out_tmp = out_tmp + alpha[i]*a[j,i]*b2[j] {id=out_up2,dep=out_init} + out[i] = out_tmp {dep=out_up1:out_up2} + """) + knl = lp.add_and_infer_dtypes(knl, + dict(a=np.float32, alpha=np.float32, b1=np.float32, b2=np.float32)) + + ref_knl = knl + + knl = lp.split_iname(knl, "i", 256, outer_tag="g.0", inner_tag="l.0") + knl = lp.collect_common_factors_on_increment(knl, "out_tmp") + + lp.auto_test_vs_ref(ref_knl, ctx, knl, parameters=dict(n=13)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab