diff --git a/test/test_statistics.py b/test/test_statistics.py index 5e363f13594ee8e4cf170faa232b0783cca9d018..82bde8f5c461cd8a55b0a2bbe49c633f11170737 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -515,7 +515,7 @@ def test_mem_access_counter_nonconsec(): stride=Variable('m')*Variable('l'), direction='load', variable='b') ].eval_with_dict(params) - assert f64nonconsec == 2*n*m + assert f64nonconsec == 2*n*m*l assert f32nonconsec == 3*n*m*l f64nonconsec = mem_map[lp.MemAccess('global', np.float64, @@ -563,7 +563,7 @@ def test_mem_access_counter_consec(): f32consec += mem_map[lp.MemAccess('global', np.dtype(np.float32), stride=1, direction='load', variable='b') ].eval_with_dict(params) - assert f64consec == 2*n*m + assert f64consec == 2*n*m*l assert f32consec == 3*n*m*l f64consec = mem_map[lp.MemAccess('global', np.float64, @@ -628,6 +628,7 @@ def test_barrier_counter_barriers(): def test_all_counters_parallel_matmul(): + bsize = 16 knl = lp.make_kernel( "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", [ @@ -635,9 +636,9 @@ def test_all_counters_parallel_matmul(): ], name="matmul", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) - knl = lp.split_iname(knl, "i", 16, outer_tag="g.0", inner_tag="l.1") - knl = lp.split_iname(knl, "j", 16, outer_tag="g.1", inner_tag="l.0") - knl = lp.split_iname(knl, "k", 16) + knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1") + knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0") + knl = lp.split_iname(knl, "k", bsize) knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"]) knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"]) @@ -649,7 +650,7 @@ def test_all_counters_parallel_matmul(): sync_map = lp.get_synchronization_map(knl) assert len(sync_map) == 2 assert sync_map["kernel_launch"].eval_with_dict(params) == 1 - assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/16 + assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize op_map = lp.get_op_map(knl) f32mul = op_map[ @@ -669,14 +670,15 @@ def test_all_counters_parallel_matmul(): op_map = lp.get_mem_access_map(knl) - f32coal = op_map[lp.MemAccess('global', np.float32, + f32s1lb = op_map[lp.MemAccess('global', np.float32, stride=1, direction='load', variable='b') ].eval_with_dict(params) - f32coal += op_map[lp.MemAccess('global', np.float32, - stride=1, direction='load', variable='a') - ].eval_with_dict(params) + f32s1la = op_map[lp.MemAccess('global', np.float32, + stride=1, direction='load', variable='a') + ].eval_with_dict(params) - assert f32coal == n*m+m*l + assert f32s1lb == (m/bsize)*n*m + assert f32s1la == (m/bsize)*m*l f32coal = op_map[lp.MemAccess('global', np.float32, stride=1, direction='store', variable='c')