diff --git a/test/test_statistics.py b/test/test_statistics.py index 82bde8f5c461cd8a55b0a2bbe49c633f11170737..87fa92a58fa6fa5801dad72f25e4b3328be49adc 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -437,8 +437,8 @@ def test_mem_access_counter_mixed(): knl = lp.add_and_infer_dtypes(knl, dict( a=np.float32, b=np.float32, g=np.float64, h=np.float64, x=np.float32)) - threads = 16 - knl = lp.split_iname(knl, "j", threads) + bsize = 16 + knl = lp.split_iname(knl, "j", bsize) knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"}) mem_map = lp.get_mem_access_map(knl) # noqa @@ -463,8 +463,8 @@ def test_mem_access_counter_mixed(): stride=Variable('m'), direction='load', variable='b') ].eval_with_dict(params) - assert f64uniform == 2*n*m - assert f32uniform == n*m*l/threads + assert f64uniform == 2*n*m*l/bsize + assert f32uniform == n*m*l/bsize assert f32nonconsec == 3*n*m*l f64uniform = mem_map[lp.MemAccess('global', np.float64, @@ -474,7 +474,7 @@ def test_mem_access_counter_mixed(): stride=Variable('m'), direction='store', variable='c') ].eval_with_dict(params) - assert f64uniform == n*m + assert f64uniform == n*m*l/bsize assert f32nonconsec == n*m*l @@ -515,7 +515,7 @@ def test_mem_access_counter_nonconsec(): stride=Variable('m')*Variable('l'), direction='load', variable='b') ].eval_with_dict(params) - assert f64nonconsec == 2*n*m*l + assert f64nonconsec == 2*n*m assert f32nonconsec == 3*n*m*l f64nonconsec = mem_map[lp.MemAccess('global', np.float64, @@ -572,7 +572,7 @@ def test_mem_access_counter_consec(): f32consec = mem_map[lp.MemAccess('global', np.float32, stride=1, direction='store', variable='c') ].eval_with_dict(params) - assert f64consec == n*m + assert f64consec == n*m*l assert f32consec == n*m*l @@ -677,8 +677,8 @@ def test_all_counters_parallel_matmul(): stride=1, direction='load', variable='a') ].eval_with_dict(params) - assert f32s1lb == (m/bsize)*n*m - assert f32s1la == (m/bsize)*m*l + assert f32s1lb == n*m*l/bsize + assert f32s1la == n*m*l/bsize f32coal = op_map[lp.MemAccess('global', np.float32, stride=1, direction='store', variable='c')