Skip to content
Snippets Groups Projects
Commit 37f83d5f authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Merge branch 'stats-with-unused-hw-axes' of...

Merge branch 'stats-with-unused-hw-axes' of ssh://gitlab.tiker.net/inducer/loopy into stats-with-unused-hw-axes
parents 7d5d2ed1 b1686701
No related branches found
No related tags found
1 merge request!121Stats: take into account unused hw axes in run count, refactor code
Pipeline #
...@@ -515,7 +515,7 @@ def test_mem_access_counter_nonconsec(): ...@@ -515,7 +515,7 @@ def test_mem_access_counter_nonconsec():
stride=Variable('m')*Variable('l'), stride=Variable('m')*Variable('l'),
direction='load', variable='b') direction='load', variable='b')
].eval_with_dict(params) ].eval_with_dict(params)
assert f64nonconsec == 2*n*m assert f64nonconsec == 2*n*m*l
assert f32nonconsec == 3*n*m*l assert f32nonconsec == 3*n*m*l
f64nonconsec = mem_map[lp.MemAccess('global', np.float64, f64nonconsec = mem_map[lp.MemAccess('global', np.float64,
...@@ -563,7 +563,7 @@ def test_mem_access_counter_consec(): ...@@ -563,7 +563,7 @@ def test_mem_access_counter_consec():
f32consec += mem_map[lp.MemAccess('global', np.dtype(np.float32), f32consec += mem_map[lp.MemAccess('global', np.dtype(np.float32),
stride=1, direction='load', variable='b') stride=1, direction='load', variable='b')
].eval_with_dict(params) ].eval_with_dict(params)
assert f64consec == 2*n*m assert f64consec == 2*n*m*l
assert f32consec == 3*n*m*l assert f32consec == 3*n*m*l
f64consec = mem_map[lp.MemAccess('global', np.float64, f64consec = mem_map[lp.MemAccess('global', np.float64,
...@@ -628,6 +628,7 @@ def test_barrier_counter_barriers(): ...@@ -628,6 +628,7 @@ def test_barrier_counter_barriers():
def test_all_counters_parallel_matmul(): def test_all_counters_parallel_matmul():
bsize = 16
knl = lp.make_kernel( knl = lp.make_kernel(
"{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}", "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
[ [
...@@ -635,9 +636,9 @@ def test_all_counters_parallel_matmul(): ...@@ -635,9 +636,9 @@ def test_all_counters_parallel_matmul():
], ],
name="matmul", assumptions="n,m,l >= 1") name="matmul", assumptions="n,m,l >= 1")
knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
knl = lp.split_iname(knl, "i", 16, outer_tag="g.0", inner_tag="l.1") knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1")
knl = lp.split_iname(knl, "j", 16, outer_tag="g.1", inner_tag="l.0") knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0")
knl = lp.split_iname(knl, "k", 16) knl = lp.split_iname(knl, "k", bsize)
knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"]) knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"])
knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"]) knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"])
...@@ -649,7 +650,7 @@ def test_all_counters_parallel_matmul(): ...@@ -649,7 +650,7 @@ def test_all_counters_parallel_matmul():
sync_map = lp.get_synchronization_map(knl) sync_map = lp.get_synchronization_map(knl)
assert len(sync_map) == 2 assert len(sync_map) == 2
assert sync_map["kernel_launch"].eval_with_dict(params) == 1 assert sync_map["kernel_launch"].eval_with_dict(params) == 1
assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/16 assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize
op_map = lp.get_op_map(knl) op_map = lp.get_op_map(knl)
f32mul = op_map[ f32mul = op_map[
...@@ -669,14 +670,15 @@ def test_all_counters_parallel_matmul(): ...@@ -669,14 +670,15 @@ def test_all_counters_parallel_matmul():
op_map = lp.get_mem_access_map(knl) op_map = lp.get_mem_access_map(knl)
f32coal = op_map[lp.MemAccess('global', np.float32, f32s1lb = op_map[lp.MemAccess('global', np.float32,
stride=1, direction='load', variable='b') stride=1, direction='load', variable='b')
].eval_with_dict(params) ].eval_with_dict(params)
f32coal += op_map[lp.MemAccess('global', np.float32, f32s1la = op_map[lp.MemAccess('global', np.float32,
stride=1, direction='load', variable='a') stride=1, direction='load', variable='a')
].eval_with_dict(params) ].eval_with_dict(params)
assert f32coal == n*m+m*l assert f32s1lb == (m/bsize)*n*m
assert f32s1la == (m/bsize)*m*l
f32coal = op_map[lp.MemAccess('global', np.float32, f32coal = op_map[lp.MemAccess('global', np.float32,
stride=1, direction='store', variable='c') stride=1, direction='store', variable='c')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment