diff --git a/test/test_statistics.py b/test/test_statistics.py index 87fa92a58fa6fa5801dad72f25e4b3328be49adc..a72b62af90050008f837e144f1f28d4a4de1c730 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -49,7 +49,7 @@ def test_op_counter_basic(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -74,7 +74,7 @@ def test_op_counter_reduction(): name="matmul_serial", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -100,7 +100,7 @@ def test_op_counter_logic(): name="logic", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -130,7 +130,7 @@ def test_op_counter_specialops(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -166,7 +166,7 @@ def test_op_counter_bitwise(): a=np.int32, b=np.int32, g=np.int64, h=np.int64)) - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -205,7 +205,7 @@ def test_op_counter_triangular_domain(): else: expect_fallback = False - op_map = lp.get_op_map(knl)[lp.Op(np.float64, 'mul')] + op_map = lp.get_op_map(knl, count_redundant_work=True)[lp.Op(np.float64, 'mul')] value_dict = dict(m=13, n=200) flops = op_map.eval_with_dict(value_dict) @@ -229,7 +229,7 @@ def test_mem_access_counter_basic(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -269,7 +269,7 @@ def test_mem_access_counter_reduction(): name="matmul", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -307,7 +307,7 @@ def test_mem_access_counter_logic(): name="logic", assumptions="n,m,l >= 1") knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -343,7 +343,7 @@ def test_mem_access_counter_specialops(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -395,7 +395,7 @@ def test_mem_access_counter_bitwise(): a=np.int32, b=np.int32, g=np.int32, h=np.int32)) - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -441,7 +441,7 @@ def test_mem_access_counter_mixed(): knl = lp.split_iname(knl, "j", bsize) knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"}) - mem_map = lp.get_mem_access_map(knl) # noqa + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) # noqa n = 512 m = 256 l = 128 @@ -494,7 +494,7 @@ def test_mem_access_counter_nonconsec(): knl = lp.split_iname(knl, "i", 16) knl = lp.tag_inames(knl, {"i_inner": "l.0", "i_outer": "g.0"}) - mem_map = lp.get_mem_access_map(knl) # noqa + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) # noqa n = 512 m = 256 l = 128 @@ -545,7 +545,7 @@ def test_mem_access_counter_consec(): a=np.float32, b=np.float32, g=np.float64, h=np.float64)) knl = lp.tag_inames(knl, {"k": "l.0", "i": "g.0", "j": "g.1"}) - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) n = 512 m = 256 l = 128 @@ -652,7 +652,7 @@ def test_all_counters_parallel_matmul(): assert sync_map["kernel_launch"].eval_with_dict(params) == 1 assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) f32mul = op_map[ lp.Op(np.float32, 'mul') ].eval_with_dict(params) @@ -668,7 +668,7 @@ def test_all_counters_parallel_matmul(): assert f32mul+f32add == n*m*l*2 - op_map = lp.get_mem_access_map(knl) + op_map = lp.get_mem_access_map(knl, count_redundant_work=True) f32s1lb = op_map[lp.MemAccess('global', np.float32, stride=1, direction='load', variable='b') @@ -686,7 +686,8 @@ def test_all_counters_parallel_matmul(): assert f32coal == n*l - local_mem_map = lp.get_mem_access_map(knl).filter_by(mtype=['local']) + local_mem_map = lp.get_mem_access_map(knl, + count_redundant_work=True).filter_by(mtype=['local']) local_mem_l = local_mem_map[lp.MemAccess('local', np.dtype(np.float32), direction='load') ].eval_with_dict(params) @@ -744,7 +745,7 @@ def test_summations_and_filters(): l = 128 params = {'n': n, 'm': m, 'l': l} - mem_map = lp.get_mem_access_map(knl) + mem_map = lp.get_mem_access_map(knl, count_redundant_work=True) loads_a = mem_map.filter_by(direction=['load'], variable=['a'] ).eval_and_sum(params) @@ -770,7 +771,7 @@ def test_summations_and_filters(): assert f32lall == 3*n*m*l assert f64lall == 2*n*m - op_map = lp.get_op_map(knl) + op_map = lp.get_op_map(knl, count_redundant_work=True) #for k, v in op_map.items(): # print(type(k), "\n", k.name, k.dtype, type(k.dtype), " :\n", v)