diff --git a/loopy/statistics.py b/loopy/statistics.py
index cee28b24f8bdd44392f41f437c6042f3aa08ce2c..2df3093d1b58babd35c61efdbcee20bba243c643 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -715,7 +715,8 @@ class ExpressionOpCounter(CounterBase):
         return ToCountMap(
                 {Op(dtype=self.type_inf(expr),
                     name='func:'+str(expr.function),
-                    count_granularity=CountGranularity.WORKITEM): 1}
+                    #count_granularity=CountGranularity.WORKITEM): 1}
+                    count_granularity=CountGranularity.SUBGROUP): 1}
                 ) + self.rec(expr.parameters)
 
     def map_subscript(self, expr):
@@ -726,7 +727,8 @@ class ExpressionOpCounter(CounterBase):
         return ToCountMap(
                 {Op(dtype=self.type_inf(expr),
                     name='add',
-                    count_granularity=CountGranularity.WORKITEM):
+                    #count_granularity=CountGranularity.WORKITEM):
+                    count_granularity=CountGranularity.SUBGROUP):
                  len(expr.children)-1}
                 ) + sum(self.rec(child) for child in expr.children)
 
@@ -735,18 +737,21 @@ class ExpressionOpCounter(CounterBase):
         assert expr.children
         return sum(ToCountMap({Op(dtype=self.type_inf(expr),
                                   name='mul',
-                                  count_granularity=CountGranularity.WORKITEM): 1})
+                                  #count_granularity=CountGranularity.WORKITEM): 1})
+                                  count_granularity=CountGranularity.SUBGROUP): 1})
                    + self.rec(child)
                    for child in expr.children
                    if not is_zero(child + 1)) + \
                    ToCountMap({Op(dtype=self.type_inf(expr),
                                   name='mul',
-                                  count_granularity=CountGranularity.WORKITEM): -1})
+                                  #count_granularity=CountGranularity.WORKITEM): -1})
+                                  count_granularity=CountGranularity.SUBGROUP): -1})
 
     def map_quotient(self, expr, *args):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='div',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                 + self.rec(expr.numerator) \
                 + self.rec(expr.denominator)
@@ -756,14 +761,16 @@ class ExpressionOpCounter(CounterBase):
 
     def map_power(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='pow',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                 + self.rec(expr.base) \
                 + self.rec(expr.exponent)
 
     def map_left_shift(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='shift',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                 + self.rec(expr.shiftee) \
                 + self.rec(expr.shift)
@@ -772,13 +779,15 @@ class ExpressionOpCounter(CounterBase):
 
     def map_bitwise_not(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='bw',
-                              count_granularity=CountGranularity.WORKITEM): 1}) \
+                              #count_granularity=CountGranularity.WORKITEM): 1}) \
+                              count_granularity=CountGranularity.SUBGROUP): 1}) \
                 + self.rec(expr.child)
 
     def map_bitwise_or(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='bw',
-                              count_granularity=CountGranularity.WORKITEM):
+                              #count_granularity=CountGranularity.WORKITEM):
+                              count_granularity=CountGranularity.SUBGROUP):
                            len(expr.children)-1}) \
                 + sum(self.rec(child) for child in expr.children)
@@ -802,7 +811,8 @@ class ExpressionOpCounter(CounterBase):
 
     def map_min(self, expr):
         return ToCountMap({Op(dtype=self.type_inf(expr),
                               name='maxmin',
-                              count_granularity=CountGranularity.WORKITEM):
+                              #count_granularity=CountGranularity.WORKITEM):
+                              count_granularity=CountGranularity.SUBGROUP):
                            len(expr.children)-1}) \
                 + sum(self.rec(child) for child in expr.children)
@@ -1329,14 +1339,109 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False,
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
 
+    if not isinstance(subgroup_size, int):
+        # try to find subgroup_size
+        subgroup_size_guess = _find_subgroup_size_for_knl(knl)
+
+        if subgroup_size is None:
+            if subgroup_size_guess is None:
+                # 'guess' was not passed and either no target device found
+                # or get_simd_group_size returned None
+                raise ValueError("No sub-group size passed, no target device found. "
+                                 "Either (1) pass integer value for subgroup_size, "
+                                 "(2) ensure that kernel.target is PyOpenClTarget "
+                                 "and kernel.target.device is set, or (3) pass "
+                                 "subgroup_size='guess' and hope for the best.")
+            else:
+                subgroup_size = subgroup_size_guess
+
+        elif subgroup_size == 'guess':
+            if subgroup_size_guess is None:
+                # unable to get subgroup_size from device, so guess
+                subgroup_size = 32
+                warn_with_kernel(knl, "get_op_map_guessing_subgroup_size",
+                                 "get_op_map: 'guess' sub-group size "
+                                 "passed, no target device found, wildly guessing "
+                                 "that sub-group size is %d." % (subgroup_size))
+            else:
+                subgroup_size = subgroup_size_guess
+        else:
+            raise ValueError("Invalid value for subgroup_size: %s. subgroup_size "
+                             "must be integer, 'guess', or, if you're feeling "
+                             "lucky, None." % (subgroup_size))
+
+    # ------------------------------
+    #class CacheHolder(object):
+    #    pass
+
+    #cache_holder = CacheHolder()
+    #from pytools import memoize_in
+
+    #@memoize_in(cache_holder, "insn_count")
+    def get_insn_count(knl, insn, count_granularity=CountGranularity.WORKITEM):
+
+        if count_granularity is None:
+            warn_with_kernel(knl, "get_insn_count_assumes_granularity",
+                             "get_insn_count: No count granularity passed for "
+                             "Op, assuming %s granularity."
+                             % (CountGranularity.WORKITEM))
+            count_granularity = CountGranularity.WORKITEM
+
+        if count_granularity == CountGranularity.WORKITEM:
+            return count_insn_runs(
+                knl, insn, count_redundant_work=count_redundant_work,
+                disregard_local_axes=False)
+
+        ct_disregard_local = count_insn_runs(
+                knl, insn, disregard_local_axes=True,
+                count_redundant_work=count_redundant_work)
+
+        if count_granularity == CountGranularity.WORKGROUP:
+            return ct_disregard_local
+        elif count_granularity == CountGranularity.SUBGROUP:
+            # get the group size
+            from loopy.symbolic import aff_to_expr
+            _, local_size = knl.get_grid_size_upper_bounds()
+            workgroup_size = 1
+            if local_size:
+                for size in local_size:
+                    s = aff_to_expr(size)
+                    if not isinstance(s, int):
+                        raise LoopyError("Cannot count insn with %s granularity, "
+                                         "work-group size is not integer: %s"
+                                         % (CountGranularity.SUBGROUP, local_size))
+                    workgroup_size *= s
+
+            warn_with_kernel(knl, "insn_count_subgroups_upper_bound",
+                    "get_insn_count: when counting instruction %s with "
+                    "count_granularity=%s, using upper bound for work-group size "
+                    "(%d work-items) to compute sub-groups per work-group. When "
+                    "multiple device programs are present, actual sub-group count "
+                    "may be lower." % (insn, CountGranularity.SUBGROUP,
+                                       workgroup_size))
+
+            from pytools import div_ceil
+            return ct_disregard_local*div_ceil(workgroup_size, subgroup_size)
+        else:
+            # this should not happen since this is enforced in Op
+            raise ValueError("get_insn_count: count_granularity '%s' is "
+                             "not allowed. count_granularity options: %s"
+                             % (count_granularity, CountGranularity.ALL+[None]))
+    # ------------------------------
+
     op_map = ToCountMap()
     op_counter = ExpressionOpCounter(knl)
     for insn in knl.instructions:
         if isinstance(insn, (CallInstruction, CInstruction, Assignment)):
             ops = op_counter(insn.assignee) + op_counter(insn.expression)
-            op_map = op_map + ops*count_insn_runs(
-                    knl, insn,
-                    count_redundant_work=count_redundant_work)
+            #op_map = op_map + ops*count_insn_runs(
+            #        knl, insn,
+            #        count_redundant_work=count_redundant_work)
+            for key, val in six.iteritems(ops):
+                op_map = (
+                        op_map
+                        + ToCountMap({key: val})
+                        * get_insn_count(knl, insn, key.count_granularity))
+
         elif isinstance(insn, (NoOpInstruction, BarrierInstruction)):
             pass
         else:
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 79c5ec7da0971b534588be3bfcd58a9f5fc8405a..b5b55347c99df3abfd5301bc037df611a67126f4 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -39,6 +39,9 @@ from pymbolic.primitives import Variable
 
 from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2  # noqa
 
 
+SGS = 32  # Subgroup size
+
+
 def test_op_counter_basic():
 
     knl = lp.make_kernel(
@@ -54,21 +57,26 @@
     knl = lp.add_and_infer_dtypes(knl,
                                   dict(a=np.float32, b=np.float32,
                                        g=np.float64, h=np.float64))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
-    f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params)
-    f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.WORKITEM)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params)
+    f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(params)
+    f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32add == f32mul == f32div == n*m*ell
-    assert f64mul == n*m
-    assert i32add == n*m*2
+    # (count-per-sub-group)*n_subgroups
+    assert f32add == f32mul == f32div == n*m*ell*n_subgroups
+    assert f64mul == n*m*n_subgroups
+    assert i32add == n*m*2*n_subgroups
 
 
 def test_op_counter_reduction():
@@ -81,15 +89,20 @@
             name="matmul_serial", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.WORKITEM)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32add == f32mul == n*m*ell
+    # (count-per-sub-group)*n_subgroups
+    assert f32add == f32mul == n*m*ell*n_subgroups
 
     op_map_dtype = op_map.group_by('dtype')
     f32 = op_map_dtype[lp.Op(dtype=np.float32)].eval_with_dict(params)
@@ -111,21 +124,26 @@
             name="logic", assumptions="n,m,ell >= 1")
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
    ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
-    f64add = op_map[lp.Op(np.float64, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.WORKITEM)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params)
+    f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32mul == n*m
-    assert f64div == 2*n*m  # TODO why?
-    assert f64add == n*m
-    assert i32add == n*m
+    # (count-per-sub-group)*n_subgroups
+    assert f32mul == n*m*n_subgroups
+    assert f64div == 2*n*m*n_subgroups  # TODO why?
+    assert f64add == n*m*n_subgroups
+    assert i32add == n*m*n_subgroups
 
 
 def test_op_counter_specialops():
@@ -143,27 +161,32 @@
     knl = lp.add_and_infer_dtypes(knl,
                                   dict(a=np.float32, b=np.float32,
                                        g=np.float64, h=np.float64))
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params)
-    f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params)
-    f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    f64pow = op_map[lp.Op(np.float64, 'pow', CG.WORKITEM)].eval_with_dict(params)
-    f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.WORKITEM)
+    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params)
+    f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(params)
+    f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    f64pow = op_map[lp.Op(np.float64, 'pow', CG.SUBGROUP)].eval_with_dict(params)
+    f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.WORKITEM)
+    f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.WORKITEM)
+    f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    assert f32div == 2*n*m*ell
-    assert f32mul == f32add == n*m*ell
-    assert f64add == 3*n*m
-    assert f64pow == i32add == f64rsq == f64sin == n*m
+    # (count-per-sub-group)*n_subgroups
+    assert f32div == 2*n*m*ell*n_subgroups
+    assert f32mul == f32add == n*m*ell*n_subgroups
+    assert f64add == 3*n*m*n_subgroups
+    assert f64pow == i32add == f64rsq == f64sin == n*m*n_subgroups
 
 
 def test_op_counter_bitwise():
@@ -183,26 +206,31 @@
                               a=np.int32, b=np.int32,
                               g=np.int64, h=np.int64))
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
     n = 512
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
-    i32add = op_map[lp.Op(np.int32, 'add', CG.WORKITEM)].eval_with_dict(params)
-    i32bw = op_map[lp.Op(np.int32, 'bw', CG.WORKITEM)].eval_with_dict(params)
-    i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.WORKITEM)
+    i32add = op_map[lp.Op(np.int32, 'add', CG.SUBGROUP)].eval_with_dict(params)
+    i32bw = op_map[lp.Op(np.int32, 'bw', CG.SUBGROUP)].eval_with_dict(params)
+    i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.SUBGROUP)
                    ].eval_with_dict(params)
-    i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.WORKITEM)
+    i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.WORKITEM)
+    i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.SUBGROUP)
                     ].eval_with_dict(params)
-    i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.WORKITEM)
+    i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.SUBGROUP)
                       ].eval_with_dict(params)
-    assert i32add == n*m+n*m*ell
-    assert i32bw == 2*n*m*ell
-    assert i64bw == 2*n*m
-    assert i64add == i64mul == n*m
-    assert i64shift == 2*n*m
+    # (count-per-sub-group)*n_subgroups
+    assert i32add == (n*m+n*m*ell)*n_subgroups
+    assert i32bw == 2*n*m*ell*n_subgroups
+    assert i64bw == 2*n*m*n_subgroups
+    assert i64add == i64mul == n*m*n_subgroups
+    assert i64shift == 2*n*m*n_subgroups
 
 
 def test_op_counter_triangular_domain():
@@ -228,15 +256,21 @@
 
     op_map = lp.get_op_map(
                     knl,
+                    subgroup_size=SGS,
                     count_redundant_work=True
-                    )[lp.Op(np.float64, 'mul', CG.WORKITEM)]
+                    )[lp.Op(np.float64, 'mul', CG.SUBGROUP)]
     value_dict = dict(m=13, n=200)
     flops = op_map.eval_with_dict(value_dict)
 
+    n_workgroups = 1
+    group_size = 1
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
+
     if expect_fallback:
-        assert flops == 144
+        assert flops == 144*n_subgroups
     else:
-        assert flops == 78
+        assert flops == 78*n_subgroups
 
 
 def test_mem_access_counter_basic():
@@ -254,10 +288,8 @@
     knl = lp.add_and_infer_dtypes(knl,
                                   dict(a=np.float32, b=np.float32,
                                        g=np.float64, h=np.float64))
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
 
     n = 512
     m = 256
@@ -266,7 +298,8 @@
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     f32l = mem_map[lp.MemAccess('global', np.float32,
                                 lid_strides={}, gid_strides={},
@@ -289,9 +322,9 @@
                                 count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32l == (3*n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64l == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32l == (3*n*m*ell)*n_subgroups
+    assert f64l == (2*n*m)*n_subgroups
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                                 lid_strides={}, gid_strides={},
@@ -304,9 +337,9 @@
                                 count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32s == (n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64s == (n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32s == (n*m*ell)*n_subgroups
+    assert f64s == (n*m)*n_subgroups
 
 
 def test_mem_access_counter_reduction():
@@ -320,10 +353,8 @@
 
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -331,7 +362,8 @@
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     f32l = mem_map[lp.MemAccess('global', np.float32,
                                 lid_strides={}, gid_strides={},
@@ -344,8 +376,8 @@
                                 count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32l == (2*n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32l == (2*n*m*ell)*n_subgroups
 
     f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32),
                                 lid_strides={}, gid_strides={},
@@ -353,8 +385,8 @@
                                 count_granularity=CG.SUBGROUP)
                    ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32s == (n*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32s == (n*ell)*n_subgroups
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load']
                                  ).to_bytes().eval_and_sum(params)
@@ -379,10 +411,8 @@
 
     knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64))
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -390,7 +420,8 @@
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
 
@@ -404,10 +435,10 @@
                                direction='store')
                   ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32_g_l == (2*n*m)*n_workgroups*subgroups_per_group
-    assert f64_g_l == (n*m)*n_workgroups*subgroups_per_group
-    assert f64_g_s == (n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32_g_l == (2*n*m)*n_subgroups
+    assert f64_g_l == (n*m)*n_subgroups
+    assert f64_g_s == (n*m)*n_subgroups
 
 
 def test_mem_access_counter_specialops():
@@ -425,10 +456,8 @@
     knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32,
                                             b=np.float32, g=np.float64,
                                             h=np.float64))
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -436,7 +465,8 @@
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     f32 = mem_map[lp.MemAccess('global', np.float32,
                                lid_strides={}, gid_strides={},
@@ -459,9 +489,9 @@
                                count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32 == (2*n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64 == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32 == (2*n*m*ell)*n_subgroups
+    assert f64 == (2*n*m)*n_subgroups
 
     f32 = mem_map[lp.MemAccess('global', np.float32,
                                lid_strides={}, gid_strides={},
@@ -474,16 +504,16 @@
                                count_granularity=CG.SUBGROUP)
                   ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32 == (n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64 == (n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32 == (n*m*ell)*n_subgroups
+    assert f64 == (n*m)*n_subgroups
 
     filtered_map = mem_map.filter_by(direction=['load'], variable=['a', 'g'],
                                      count_granularity=CG.SUBGROUP)
     tot = filtered_map.eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert tot == (n*m*ell + n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert tot == (n*m*ell + n*m)*n_subgroups
 
 
 def test_mem_access_counter_bitwise():
@@ -503,10 +533,8 @@
                               a=np.int32, b=np.int32,
                               g=np.int32, h=np.int32))
 
-    subgroup_size = 32
-
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
     n = 512
     m = 256
     ell = 128
@@ -514,7 +542,8 @@
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     i32 = mem_map[lp.MemAccess('global', np.int32,
                                lid_strides={}, gid_strides={},
@@ -537,8 +566,8 @@
                               count_granularity=CG.SUBGROUP)
                  ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert i32 == (4*n*m+2*n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert i32 == (4*n*m+2*n*m*ell)*n_subgroups
 
     i32 = mem_map[lp.MemAccess('global', np.int32,
                                lid_strides={}, gid_strides={},
@@ -551,8 +580,8 @@
                               count_granularity=CG.SUBGROUP)
                  ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert i32 == (n*m+n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert i32 == (n*m+n*m*ell)*n_subgroups
 
 
 def test_mem_access_counter_mixed():
@@ -571,7 +600,6 @@
                             x=np.float32))
 
     group_size_0 = 65
-    subgroup_size = 32
 
     knl = lp.split_iname(knl, "j", group_size_0)
     knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"})
@@ -583,10 +611,11 @@
 
     n_workgroups = div_ceil(ell, group_size_0)
     group_size = group_size_0
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
 
     f64uniform = mem_map[lp.MemAccess('global', np.float64,
                                       lid_strides={}, gid_strides={},
                                       direction='load', variable='g',
@@ -617,9 +646,9 @@
                                       count_granularity=CG.WORKITEM)
                          ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f64uniform == (2*n*m)*n_workgroups*subgroups_per_group
-    assert f32uniform == (m*n)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f64uniform == (2*n*m)*n_subgroups
+    assert f32uniform == (m*n)*n_subgroups
 
     expect_fallback = False
     import islpy as isl
@@ -651,8 +680,8 @@
                                       count_granularity=CG.WORKITEM)
                          ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f64uniform == m*n*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f64uniform == m*n*n_subgroups
 
     if expect_fallback:
         if ell < group_size_0:
@@ -681,7 +710,7 @@
     knl = lp.tag_inames(knl, {"i_inner": "l.0", "i_outer": "g.0"})
 
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=32)  # noqa
+                                    subgroup_size=SGS)  # noqa
     n = 512
     m = 256
     ell = 128
@@ -939,30 +968,35 @@
     m = 256
     ell = 128
     params = {'n': n, 'm': m, 'ell': ell}
+    group_size = bsize*bsize
+    n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     sync_map = lp.get_synchronization_map(knl)
     assert len(sync_map) == 2
     assert sync_map["kernel_launch"].eval_with_dict(params) == 1
     assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
     f32mul = op_map[
-            lp.Op(np.float32, 'mul', CG.WORKITEM)
+            lp.Op(np.float32, 'mul', CG.SUBGROUP)
             ].eval_with_dict(params)
     f32add = op_map[
-            lp.Op(np.float32, 'add', CG.WORKITEM)
+            lp.Op(np.float32, 'add', CG.SUBGROUP)
             ].eval_with_dict(params)
     i32ops = op_map[
-            lp.Op(np.int32, 'add', CG.WORKITEM)
+            lp.Op(np.int32, 'add', CG.SUBGROUP)
            ].eval_with_dict(params)
     i32ops += op_map[
-            lp.Op(np.dtype(np.int32), 'mul', CG.WORKITEM)
+            lp.Op(np.dtype(np.int32), 'mul', CG.SUBGROUP)
             ].eval_with_dict(params)
 
-    assert f32mul+f32add == n*m*ell*2
+    # (count-per-sub-group)*n_subgroups
+    assert f32mul+f32add == m*2*n_subgroups
 
     mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                           subgroup_size=32)
+                                           subgroup_size=SGS)
 
     f32s1lb = mem_access_map[lp.MemAccess('global', np.float32,
                                           lid_strides={0: 1, 1: Variable('ell')},
@@ -991,7 +1025,7 @@
 
     local_mem_map = lp.get_mem_access_map(knl,
                                           count_redundant_work=True,
-                                          subgroup_size=32).filter_by(mtype=['local'])
+                                          subgroup_size=SGS).filter_by(mtype=['local'])
     local_mem_l = local_mem_map.filter_by(direction=['load']
                                           ).eval_and_sum(params)
@@ -1067,8 +1101,6 @@
     knl = lp.add_and_infer_dtypes(knl,
                                   dict(a=np.float32, b=np.float32,
                                        g=np.float64, h=np.float64))
 
-    subgroup_size = 32
-
     n = 512
     m = 256
     ell = 128
@@ -1076,24 +1108,25 @@
 
     n_workgroups = 1
     group_size = 1
-    subgroups_per_group = div_ceil(group_size, subgroup_size)
+    subgroups_per_group = div_ceil(group_size, SGS)
+    n_subgroups = n_workgroups*subgroups_per_group
 
     mem_map = lp.get_mem_access_map(knl, count_redundant_work=True,
-                                    subgroup_size=subgroup_size)
+                                    subgroup_size=SGS)
 
     loads_a = mem_map.filter_by(direction=['load'], variable=['a'],
                                 count_granularity=[CG.SUBGROUP]
                                 ).eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert loads_a == (2*n*m*ell)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert loads_a == (2*n*m*ell)*n_subgroups
 
     global_stores = mem_map.filter_by(mtype=['global'], direction=['store'],
                                       count_granularity=[CG.SUBGROUP]
                                       ).eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert global_stores == (n*m*ell + n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert global_stores == (n*m*ell + n*m)*n_subgroups
 
     ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load'],
                                  count_granularity=[CG.SUBGROUP]
@@ -1102,9 +1135,9 @@
                                  count_granularity=[CG.SUBGROUP]
                                  ).to_bytes().eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_workgroups*subgroups_per_group
-    assert st_bytes == (4*n*m*ell + 8*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_subgroups
+    assert st_bytes == (4*n*m*ell + 8*n*m)*n_subgroups
 
     # ignore stride and variable names in this map
     reduced_map = mem_map.group_by('mtype', 'dtype', 'direction')
@@ -1113,11 +1146,11 @@
     f64lall = reduced_map[lp.MemAccess('global', np.float64,
                                        direction='load')
                           ].eval_with_dict(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f32lall == (3*n*m*ell)*n_workgroups*subgroups_per_group
-    assert f64lall == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f32lall == (3*n*m*ell)*n_subgroups
+    assert f64lall == (2*n*m)*n_subgroups
 
-    op_map = lp.get_op_map(knl, count_redundant_work=True)
+    op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
     #for k, v in op_map.items():
     #    print(type(k), "\n", k.name, k.dtype, type(k.dtype), " :\n", v)
@@ -1149,8 +1182,8 @@
                 key.direction == 'load'
 
     f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params)
 
-    # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group
-    assert f64l == (2*n*m)*n_workgroups*subgroups_per_group
+    # uniform: (count-per-sub-group)*n_subgroups
+    assert f64l == (2*n*m)*n_subgroups
 
 
 def test_strided_footprint():
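
Usage sketch for the get_op_map change above (illustrative, not part of the
patch; the kernel is hypothetical, and the sub-group size 32 mirrors the SGS
constant used in the tests):

    import numpy as np
    import loopy as lp
    from loopy.statistics import CountGranularity as CG
    from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2  # noqa

    # hypothetical one-instruction kernel for demonstration
    knl = lp.make_kernel(
            "{[i]: 0<=i<n}",
            "out[i] = 2*a[i]")
    knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32))

    # get_op_map now takes subgroup_size (an int, 'guess', or None with a
    # PyOpenCL target device set) and reports arithmetic ops with SUBGROUP
    # granularity: counts are per sub-group, scaled by the number of
    # sub-groups per work-group (here 1, since the kernel has no local axes).
    op_map = lp.get_op_map(knl, subgroup_size=32, count_redundant_work=True)
    f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)
                    ].eval_with_dict({'n': 512})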