diff --git a/doc/tutorial.rst b/doc/tutorial.rst index 71b8f4389735f921205eb87440d542fb98fe3c18..1272d2a59119725a903fa7cd1a08b7de8629c6f6 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1551,13 +1551,13 @@ information provided. Now we will count the operations: .. doctest:: - >>> op_map = lp.get_op_map(knl) + >>> op_map = lp.get_op_map(knl, subgroup_size=32) >>> print(lp.stringify_stats_mapping(op_map)) - Op(np:dtype('float32'), add, workitem) : ... + Op(np:dtype('float32'), add, subgroup) : ... Each line of output will look roughly like:: - Op(np:dtype('float32'), add, workitem) : [l, m, n] -> { l * m * n : l > 0 and m > 0 and n > 0 } + Op(np:dtype('float32'), add, subgroup) : [l, m, n] -> { l * m * n : l > 0 and m > 0 and n > 0 } :func:`loopy.get_op_map` returns a :class:`loopy.ToCountMap` of **{** :class:`loopy.Op` **:** :class:`islpy.PwQPolynomial` **}**. A @@ -1579,12 +1579,12 @@ One way to evaluate these polynomials is with :func:`islpy.eval_with_dict`: >>> param_dict = {'n': 256, 'm': 256, 'l': 8} >>> from loopy.statistics import CountGranularity as CG - >>> f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(param_dict) - >>> f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(param_dict) - >>> f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(param_dict) - >>> f64add = op_map[lp.Op(np.float64, 'add', CG.WORKITEM)].eval_with_dict(param_dict) - >>> f64mul = op_map[lp.Op(np.float64, 'mul', CG.WORKITEM)].eval_with_dict(param_dict) - >>> i32add = op_map[lp.Op(np.int32, 'add', CG.WORKITEM)].eval_with_dict(param_dict) + >>> f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(param_dict) + >>> f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(param_dict) + >>> f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(param_dict) + >>> f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP)].eval_with_dict(param_dict) + >>> f64mul = op_map[lp.Op(np.float64, 'mul', CG.SUBGROUP)].eval_with_dict(param_dict) + >>> i32add = op_map[lp.Op(np.int32, 'add', CG.SUBGROUP)].eval_with_dict(param_dict) >>> print("%i\n%i\n%i\n%i\n%i\n%i" % ... (f32add, f32div, f32mul, f64add, f64mul, i32add)) 524288 diff --git a/loopy/statistics.py b/loopy/statistics.py index cee28b24f8bdd44392f41f437c6042f3aa08ce2c..3fecfb778c81ff9db101abca543ae6992e0b3575 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -32,7 +32,7 @@ from functools import reduce from loopy.kernel.data import ( MultiAssignmentBase, TemporaryVariable, AddressSpace) from loopy.diagnostic import warn_with_kernel, LoopyError -from pytools import Record +from pytools import Record, memoize_method __doc__ = """ @@ -715,7 +715,7 @@ class ExpressionOpCounter(CounterBase): return ToCountMap( {Op(dtype=self.type_inf(expr), name='func:'+str(expr.function), - count_granularity=CountGranularity.WORKITEM): 1} + count_granularity=CountGranularity.SUBGROUP): 1} ) + self.rec(expr.parameters) def map_subscript(self, expr): @@ -726,7 +726,7 @@ class ExpressionOpCounter(CounterBase): return ToCountMap( {Op(dtype=self.type_inf(expr), name='add', - count_granularity=CountGranularity.WORKITEM): + count_granularity=CountGranularity.SUBGROUP): len(expr.children)-1} ) + sum(self.rec(child) for child in expr.children) @@ -735,18 +735,18 @@ class ExpressionOpCounter(CounterBase): assert expr.children return sum(ToCountMap({Op(dtype=self.type_inf(expr), name='mul', - count_granularity=CountGranularity.WORKITEM): 1}) + count_granularity=CountGranularity.SUBGROUP): 1}) + self.rec(child) for child in expr.children if not is_zero(child + 1)) + \ ToCountMap({Op(dtype=self.type_inf(expr), name='mul', - count_granularity=CountGranularity.WORKITEM): -1}) + count_granularity=CountGranularity.SUBGROUP): -1}) def map_quotient(self, expr, *args): return ToCountMap({Op(dtype=self.type_inf(expr), name='div', - count_granularity=CountGranularity.WORKITEM): 1}) \ + count_granularity=CountGranularity.SUBGROUP): 1}) \ + self.rec(expr.numerator) \ + self.rec(expr.denominator) @@ -756,14 +756,14 @@ class ExpressionOpCounter(CounterBase): def map_power(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='pow', - count_granularity=CountGranularity.WORKITEM): 1}) \ + count_granularity=CountGranularity.SUBGROUP): 1}) \ + self.rec(expr.base) \ + self.rec(expr.exponent) def map_left_shift(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='shift', - count_granularity=CountGranularity.WORKITEM): 1}) \ + count_granularity=CountGranularity.SUBGROUP): 1}) \ + self.rec(expr.shiftee) \ + self.rec(expr.shift) @@ -772,13 +772,13 @@ class ExpressionOpCounter(CounterBase): def map_bitwise_not(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='bw', - count_granularity=CountGranularity.WORKITEM): 1}) \ + count_granularity=CountGranularity.SUBGROUP): 1}) \ + self.rec(expr.child) def map_bitwise_or(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='bw', - count_granularity=CountGranularity.WORKITEM): + count_granularity=CountGranularity.SUBGROUP): len(expr.children)-1}) \ + sum(self.rec(child) for child in expr.children) @@ -802,7 +802,7 @@ class ExpressionOpCounter(CounterBase): def map_min(self, expr): return ToCountMap({Op(dtype=self.type_inf(expr), name='maxmin', - count_granularity=CountGranularity.WORKITEM): + count_granularity=CountGranularity.SUBGROUP): len(expr.children)-1}) \ + sum(self.rec(child) for child in expr.children) @@ -918,7 +918,7 @@ class LocalMemAccessCounter(MemAccessCounter): sub_map[MemAccess( mtype='local', dtype=dtype, - count_granularity=CountGranularity.WORKITEM) + count_granularity=CountGranularity.SUBGROUP) ] = 1 return sub_map @@ -938,7 +938,7 @@ class LocalMemAccessCounter(MemAccessCounter): lid_strides=dict(sorted(six.iteritems(lid_strides))), gid_strides=dict(sorted(six.iteritems(gid_strides))), variable=name, - count_granularity=CountGranularity.WORKITEM)] = 1 + count_granularity=CountGranularity.SUBGROUP)] = 1 return sub_map @@ -1255,6 +1255,59 @@ def count_insn_runs(knl, insn, count_redundant_work, disregard_local_axes=False) else: return c + +@memoize_method +def _get_insn_count(knl, insn_id, subgroup_size, count_redundant_work, + count_granularity=CountGranularity.WORKITEM): + insn = knl.id_to_insn[insn_id] + + if count_granularity is None: + warn_with_kernel(knl, "get_insn_count_assumes_granularity", + "get_insn_count: No count granularity passed, " + "assuming %s granularity." + % (CountGranularity.WORKITEM)) + count_granularity == CountGranularity.WORKITEM + + if count_granularity == CountGranularity.WORKITEM: + return count_insn_runs( + knl, insn, count_redundant_work=count_redundant_work, + disregard_local_axes=False) + + ct_disregard_local = count_insn_runs( + knl, insn, disregard_local_axes=True, + count_redundant_work=count_redundant_work) + + if count_granularity == CountGranularity.WORKGROUP: + return ct_disregard_local + elif count_granularity == CountGranularity.SUBGROUP: + # get the group size + from loopy.symbolic import aff_to_expr + _, local_size = knl.get_grid_size_upper_bounds() + workgroup_size = 1 + if local_size: + for size in local_size: + s = aff_to_expr(size) + if not isinstance(s, int): + raise LoopyError("Cannot count insn with %s granularity, " + "work-group size is not integer: %s" + % (CountGranularity.SUBGROUP, local_size)) + workgroup_size *= s + + warn_with_kernel(knl, "insn_count_subgroups_upper_bound", + "get_insn_count: when counting instruction %s with " + "count_granularity=%s, using upper bound for work-group size " + "(%d work-items) to compute sub-groups per work-group. When " + "multiple device programs present, actual sub-group count may be" + "lower." % (insn_id, CountGranularity.SUBGROUP, workgroup_size)) + + from pytools import div_ceil + return ct_disregard_local*div_ceil(workgroup_size, subgroup_size) + else: + # this should not happen since this is enforced in Op/MemAccess + raise ValueError("get_insn_count: count_granularity '%s' is" + "not allowed. count_granularity options: %s" + % (count_granularity, CountGranularity.ALL+[None])) + # }}} @@ -1322,21 +1375,30 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False, raise LoopyError("Kernel '%s': Using operation counting requires the option " "ignore_boostable_into to be set." % knl.name) + subgroup_size = _process_subgroup_size(knl, subgroup_size) + from loopy.preprocess import preprocess_kernel, infer_unknown_types - from loopy.kernel.instruction import ( - CallInstruction, CInstruction, Assignment, - NoOpInstruction, BarrierInstruction) knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) op_map = ToCountMap() op_counter = ExpressionOpCounter(knl) + + from loopy.kernel.instruction import ( + CallInstruction, CInstruction, Assignment, + NoOpInstruction, BarrierInstruction) + for insn in knl.instructions: if isinstance(insn, (CallInstruction, CInstruction, Assignment)): ops = op_counter(insn.assignee) + op_counter(insn.expression) - op_map = op_map + ops*count_insn_runs( - knl, insn, - count_redundant_work=count_redundant_work) + for key, val in six.iteritems(ops.count_map): + op_map = ( + op_map + + ToCountMap({key: val}) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) + elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass else: @@ -1374,6 +1436,44 @@ def _find_subgroup_size_for_knl(knl): return None +@memoize_method +def _process_subgroup_size(knl, subgroup_size_requested): + + if isinstance(subgroup_size_requested, int): + return subgroup_size_requested + else: + # try to find subgroup_size + subgroup_size_guess = _find_subgroup_size_for_knl(knl) + + if subgroup_size_requested is None: + if subgroup_size_guess is None: + # 'guess' was not passed and either no target device found + # or get_simd_group_size returned None + raise ValueError("No sub-group size passed, no target device found. " + "Either (1) pass integer value for subgroup_size, " + "(2) ensure that kernel.target is PyOpenClTarget " + "and kernel.target.device is set, or (3) pass " + "subgroup_size='guess' and hope for the best.") + else: + return subgroup_size_guess + + elif subgroup_size_requested == 'guess': + if subgroup_size_guess is None: + # unable to get subgroup_size from device, so guess + subgroup_size_guess = 32 + warn_with_kernel(knl, "get_x_map_guessing_subgroup_size", + "'guess' sub-group size passed, no target device " + "found, wildly guessing that sub-group size is %d." + % (subgroup_size_guess)) + return subgroup_size_guess + else: + return subgroup_size_guess + else: + raise ValueError("Invalid value for subgroup_size: %s. subgroup_size " + "must be integer, 'guess', or, if you're feeling " + "lucky, None." % (subgroup_size_requested)) + + # {{{ get_mem_access_map def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, @@ -1462,100 +1562,14 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, # (now use these counts to, e.g., predict performance) """ - from loopy.preprocess import preprocess_kernel, infer_unknown_types if not knl.options.ignore_boostable_into: raise LoopyError("Kernel '%s': Using operation counting requires the option " "ignore_boostable_into to be set." % knl.name) - if not isinstance(subgroup_size, int): - # try to find subgroup_size - subgroup_size_guess = _find_subgroup_size_for_knl(knl) - - if subgroup_size is None: - if subgroup_size_guess is None: - # 'guess' was not passed and either no target device found - # or get_simd_group_size returned None - raise ValueError("No sub-group size passed, no target device found. " - "Either (1) pass integer value for subgroup_size, " - "(2) ensure that kernel.target is PyOpenClTarget " - "and kernel.target.device is set, or (3) pass " - "subgroup_size='guess' and hope for the best.") - else: - subgroup_size = subgroup_size_guess - - elif subgroup_size == 'guess': - if subgroup_size_guess is None: - # unable to get subgroup_size from device, so guess - subgroup_size = 32 - warn_with_kernel(knl, "get_mem_access_map_guessing_subgroup_size", - "get_mem_access_map: 'guess' sub-group size " - "passed, no target device found, wildly guessing " - "that sub-group size is %d." % (subgroup_size)) - else: - subgroup_size = subgroup_size_guess - else: - raise ValueError("Invalid value for subgroup_size: %s. subgroup_size " - "must be integer, 'guess', or, if you're feeling " - "lucky, None." % (subgroup_size)) - - class CacheHolder(object): - pass - - cache_holder = CacheHolder() - from pytools import memoize_in - - @memoize_in(cache_holder, "insn_count") - def get_insn_count(knl, insn_id, count_granularity=CountGranularity.WORKITEM): - insn = knl.id_to_insn[insn_id] - - if count_granularity is None: - warn_with_kernel(knl, "get_insn_count_assumes_granularity", - "get_insn_count: No count granularity passed for " - "MemAccess, assuming %s granularity." - % (CountGranularity.WORKITEM)) - count_granularity == CountGranularity.WORKITEM - - if count_granularity == CountGranularity.WORKITEM: - return count_insn_runs( - knl, insn, count_redundant_work=count_redundant_work, - disregard_local_axes=False) - - ct_disregard_local = count_insn_runs( - knl, insn, disregard_local_axes=True, - count_redundant_work=count_redundant_work) - - if count_granularity == CountGranularity.WORKGROUP: - return ct_disregard_local - elif count_granularity == CountGranularity.SUBGROUP: - # get the group size - from loopy.symbolic import aff_to_expr - _, local_size = knl.get_grid_size_upper_bounds() - workgroup_size = 1 - if local_size: - for size in local_size: - s = aff_to_expr(size) - if not isinstance(s, int): - raise LoopyError("Cannot count insn with %s granularity, " - "work-group size is not integer: %s" - % (CountGranularity.SUBGROUP, local_size)) - workgroup_size *= s - - warn_with_kernel(knl, "insn_count_subgroups_upper_bound", - "get_insn_count: when counting instruction %s with " - "count_granularity=%s, using upper bound for work-group size " - "(%d work-items) to compute sub-groups per work-group. When " - "multiple device programs present, actual sub-group count may be" - "lower." % (insn_id, CountGranularity.SUBGROUP, workgroup_size)) - - from pytools import div_ceil - return ct_disregard_local*div_ceil(workgroup_size, subgroup_size) - else: - # this should not happen since this is enforced in MemAccess - raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + subgroup_size = _process_subgroup_size(knl, subgroup_size) + from loopy.preprocess import preprocess_kernel, infer_unknown_types knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) @@ -1584,14 +1598,18 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, access_map = ( access_map + ToCountMap({key: val}) - * get_insn_count(knl, insn.id, key.count_granularity)) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) for key, val in six.iteritems(access_assignee.count_map): access_map = ( access_map + ToCountMap({key: val}) - * get_insn_count(knl, insn.id, key.count_granularity)) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass diff --git a/test/test_numa_diff.py b/test/test_numa_diff.py index 6b578838d99cb5aa28296619fdec6e8a2359ba0b..15d5ea7c98b6bde2aab89441a908b71324faae16 100644 --- a/test/test_numa_diff.py +++ b/test/test_numa_diff.py @@ -231,7 +231,7 @@ def test_gnuma_horiz_kernel(ctx_factory, ilp_multiple, Nq, opt_level): # noqa if 1: print("OPS") - op_map = lp.get_op_map(hsv) + op_map = lp.get_op_map(hsv, subgroup_size=32) print(lp.stringify_stats_mapping(op_map)) print("MEM") diff --git a/test/test_statistics.py b/test/test_statistics.py index 79c5ec7da0971b534588be3bfcd58a9f5fc8405a..3f2366521673597f0cd7e96a22780ffd2c89bdc1 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -39,6 +39,9 @@ from pymbolic.primitives import Variable from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa +SGS = 32 # Subgroup size + + def test_op_counter_basic(): knl = lp.make_kernel( @@ -54,21 +57,26 @@ def test_op_counter_basic(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group n = 512 m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params) - f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params) - f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params) - f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.WORKITEM) + f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params) + f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params) + f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(params) + f64mul = op_map[lp.Op(np.dtype(np.float64), 'mul', CG.SUBGROUP) ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM) + i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP) ].eval_with_dict(params) - assert f32add == f32mul == f32div == n*m*ell - assert f64mul == n*m - assert i32add == n*m*2 + # (count-per-sub-group)*n_subgroups + assert f32add == f32mul == f32div == n*m*ell*n_subgroups + assert f64mul == n*m*n_subgroups + assert i32add == n*m*2*n_subgroups def test_op_counter_reduction(): @@ -81,15 +89,20 @@ def test_op_counter_reduction(): name="matmul_serial", assumptions="n,m,ell >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group n = 512 m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params) - f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.WORKITEM) + f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params) + f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.SUBGROUP) ].eval_with_dict(params) - assert f32add == f32mul == n*m*ell + # (count-per-sub-group)*n_subgroups + assert f32add == f32mul == n*m*ell*n_subgroups op_map_dtype = op_map.group_by('dtype') f32 = op_map_dtype[lp.Op(dtype=np.float32)].eval_with_dict(params) @@ -111,21 +124,26 @@ def test_op_counter_logic(): name="logic", assumptions="n,m,ell >= 1") knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group n = 512 m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params) - f64add = op_map[lp.Op(np.float64, 'add', CG.WORKITEM)].eval_with_dict(params) - f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.WORKITEM) + f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params) + f64add = op_map[lp.Op(np.float64, 'add', CG.SUBGROUP)].eval_with_dict(params) + f64div = op_map[lp.Op(np.dtype(np.float64), 'div', CG.SUBGROUP) ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM) + i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP) ].eval_with_dict(params) - assert f32mul == n*m - assert f64div == 2*n*m # TODO why? - assert f64add == n*m - assert i32add == n*m + # (count-per-sub-group)*n_subgroups + assert f32mul == n*m*n_subgroups + assert f64div == 2*n*m*n_subgroups # TODO why? + assert f64add == n*m*n_subgroups + assert i32add == n*m*n_subgroups def test_op_counter_specialops(): @@ -143,27 +161,32 @@ def test_op_counter_specialops(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group n = 512 m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32mul = op_map[lp.Op(np.float32, 'mul', CG.WORKITEM)].eval_with_dict(params) - f32div = op_map[lp.Op(np.float32, 'div', CG.WORKITEM)].eval_with_dict(params) - f32add = op_map[lp.Op(np.float32, 'add', CG.WORKITEM)].eval_with_dict(params) - f64pow = op_map[lp.Op(np.float64, 'pow', CG.WORKITEM)].eval_with_dict(params) - f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.WORKITEM) + f32mul = op_map[lp.Op(np.float32, 'mul', CG.SUBGROUP)].eval_with_dict(params) + f32div = op_map[lp.Op(np.float32, 'div', CG.SUBGROUP)].eval_with_dict(params) + f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params) + f64pow = op_map[lp.Op(np.float64, 'pow', CG.SUBGROUP)].eval_with_dict(params) + f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.SUBGROUP) ].eval_with_dict(params) - i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.WORKITEM) + i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP) ].eval_with_dict(params) - f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.WORKITEM) + f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.SUBGROUP) ].eval_with_dict(params) - f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.WORKITEM) + f64sin = op_map[lp.Op(np.dtype(np.float64), 'func:sin', CG.SUBGROUP) ].eval_with_dict(params) - assert f32div == 2*n*m*ell - assert f32mul == f32add == n*m*ell - assert f64add == 3*n*m - assert f64pow == i32add == f64rsq == f64sin == n*m + # (count-per-sub-group)*n_subgroups + assert f32div == 2*n*m*ell*n_subgroups + assert f32mul == f32add == n*m*ell*n_subgroups + assert f64add == 3*n*m*n_subgroups + assert f64pow == i32add == f64rsq == f64sin == n*m*n_subgroups def test_op_counter_bitwise(): @@ -183,26 +206,31 @@ def test_op_counter_bitwise(): a=np.int32, b=np.int32, g=np.int64, h=np.int64)) - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group n = 512 m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - i32add = op_map[lp.Op(np.int32, 'add', CG.WORKITEM)].eval_with_dict(params) - i32bw = op_map[lp.Op(np.int32, 'bw', CG.WORKITEM)].eval_with_dict(params) - i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.WORKITEM) + i32add = op_map[lp.Op(np.int32, 'add', CG.SUBGROUP)].eval_with_dict(params) + i32bw = op_map[lp.Op(np.int32, 'bw', CG.SUBGROUP)].eval_with_dict(params) + i64bw = op_map[lp.Op(np.dtype(np.int64), 'bw', CG.SUBGROUP) ].eval_with_dict(params) - i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.WORKITEM) + i64mul = op_map[lp.Op(np.dtype(np.int64), 'mul', CG.SUBGROUP) ].eval_with_dict(params) - i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.WORKITEM) + i64add = op_map[lp.Op(np.dtype(np.int64), 'add', CG.SUBGROUP) ].eval_with_dict(params) - i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.WORKITEM) + i64shift = op_map[lp.Op(np.dtype(np.int64), 'shift', CG.SUBGROUP) ].eval_with_dict(params) - assert i32add == n*m+n*m*ell - assert i32bw == 2*n*m*ell - assert i64bw == 2*n*m - assert i64add == i64mul == n*m - assert i64shift == 2*n*m + # (count-per-sub-group)*n_subgroups + assert i32add == n*m+n*m*ell*n_subgroups + assert i32bw == 2*n*m*ell*n_subgroups + assert i64bw == 2*n*m*n_subgroups + assert i64add == i64mul == n*m*n_subgroups + assert i64shift == 2*n*m*n_subgroups def test_op_counter_triangular_domain(): @@ -228,15 +256,21 @@ def test_op_counter_triangular_domain(): op_map = lp.get_op_map( knl, + subgroup_size=SGS, count_redundant_work=True - )[lp.Op(np.float64, 'mul', CG.WORKITEM)] + )[lp.Op(np.float64, 'mul', CG.SUBGROUP)] value_dict = dict(m=13, n=200) flops = op_map.eval_with_dict(value_dict) + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group + if expect_fallback: - assert flops == 144 + assert flops == 144*n_subgroups else: - assert flops == 78 + assert flops == 78*n_subgroups def test_mem_access_counter_basic(): @@ -254,10 +288,8 @@ def test_mem_access_counter_basic(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - subgroup_size = 32 - mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) n = 512 m = 256 @@ -266,7 +298,8 @@ def test_mem_access_counter_basic(): n_workgroups = 1 group_size = 1 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group f32l = mem_map[lp.MemAccess('global', np.float32, lid_strides={}, gid_strides={}, @@ -289,9 +322,9 @@ def test_mem_access_counter_basic(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32l == (3*n*m*ell)*n_workgroups*subgroups_per_group - assert f64l == (2*n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32l == (3*n*m*ell)*n_subgroups + assert f64l == (2*n*m)*n_subgroups f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32), lid_strides={}, gid_strides={}, @@ -304,9 +337,9 @@ def test_mem_access_counter_basic(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32s == (n*m*ell)*n_workgroups*subgroups_per_group - assert f64s == (n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32s == (n*m*ell)*n_subgroups + assert f64s == (n*m)*n_subgroups def test_mem_access_counter_reduction(): @@ -320,10 +353,8 @@ def test_mem_access_counter_reduction(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) - subgroup_size = 32 - mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) n = 512 m = 256 ell = 128 @@ -331,7 +362,8 @@ def test_mem_access_counter_reduction(): n_workgroups = 1 group_size = 1 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group f32l = mem_map[lp.MemAccess('global', np.float32, lid_strides={}, gid_strides={}, @@ -344,8 +376,8 @@ def test_mem_access_counter_reduction(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32l == (2*n*m*ell)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32l == (2*n*m*ell)*n_subgroups f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32), lid_strides={}, gid_strides={}, @@ -353,8 +385,8 @@ def test_mem_access_counter_reduction(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32s == (n*ell)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32s == (n*ell)*n_subgroups ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load'] ).to_bytes().eval_and_sum(params) @@ -379,10 +411,8 @@ def test_mem_access_counter_logic(): knl = lp.add_and_infer_dtypes(knl, dict(g=np.float32, h=np.float64)) - subgroup_size = 32 - mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) n = 512 m = 256 ell = 128 @@ -390,7 +420,8 @@ def test_mem_access_counter_logic(): n_workgroups = 1 group_size = 1 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group reduced_map = mem_map.group_by('mtype', 'dtype', 'direction') @@ -404,10 +435,10 @@ def test_mem_access_counter_logic(): direction='store') ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32_g_l == (2*n*m)*n_workgroups*subgroups_per_group - assert f64_g_l == (n*m)*n_workgroups*subgroups_per_group - assert f64_g_s == (n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32_g_l == (2*n*m)*n_subgroups + assert f64_g_l == (n*m)*n_subgroups + assert f64_g_s == (n*m)*n_subgroups def test_mem_access_counter_specialops(): @@ -425,10 +456,8 @@ def test_mem_access_counter_specialops(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - subgroup_size = 32 - mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) n = 512 m = 256 ell = 128 @@ -436,7 +465,8 @@ def test_mem_access_counter_specialops(): n_workgroups = 1 group_size = 1 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group f32 = mem_map[lp.MemAccess('global', np.float32, lid_strides={}, gid_strides={}, @@ -459,9 +489,9 @@ def test_mem_access_counter_specialops(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32 == (2*n*m*ell)*n_workgroups*subgroups_per_group - assert f64 == (2*n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32 == (2*n*m*ell)*n_subgroups + assert f64 == (2*n*m)*n_subgroups f32 = mem_map[lp.MemAccess('global', np.float32, lid_strides={}, gid_strides={}, @@ -474,16 +504,16 @@ def test_mem_access_counter_specialops(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32 == (n*m*ell)*n_workgroups*subgroups_per_group - assert f64 == (n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32 == (n*m*ell)*n_subgroups + assert f64 == (n*m)*n_subgroups filtered_map = mem_map.filter_by(direction=['load'], variable=['a', 'g'], count_granularity=CG.SUBGROUP) tot = filtered_map.eval_and_sum(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert tot == (n*m*ell + n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert tot == (n*m*ell + n*m)*n_subgroups def test_mem_access_counter_bitwise(): @@ -503,10 +533,8 @@ def test_mem_access_counter_bitwise(): a=np.int32, b=np.int32, g=np.int32, h=np.int32)) - subgroup_size = 32 - mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) n = 512 m = 256 ell = 128 @@ -514,7 +542,8 @@ def test_mem_access_counter_bitwise(): n_workgroups = 1 group_size = 1 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group i32 = mem_map[lp.MemAccess('global', np.int32, lid_strides={}, gid_strides={}, @@ -537,8 +566,8 @@ def test_mem_access_counter_bitwise(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert i32 == (4*n*m+2*n*m*ell)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert i32 == (4*n*m+2*n*m*ell)*n_subgroups i32 = mem_map[lp.MemAccess('global', np.int32, lid_strides={}, gid_strides={}, @@ -551,8 +580,8 @@ def test_mem_access_counter_bitwise(): count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert i32 == (n*m+n*m*ell)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert i32 == (n*m+n*m*ell)*n_subgroups def test_mem_access_counter_mixed(): @@ -571,7 +600,6 @@ def test_mem_access_counter_mixed(): x=np.float32)) group_size_0 = 65 - subgroup_size = 32 knl = lp.split_iname(knl, "j", group_size_0) knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"}) @@ -583,10 +611,11 @@ def test_mem_access_counter_mixed(): n_workgroups = div_ceil(ell, group_size_0) group_size = group_size_0 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) f64uniform = mem_map[lp.MemAccess('global', np.float64, lid_strides={}, gid_strides={}, direction='load', variable='g', @@ -617,9 +646,9 @@ def test_mem_access_counter_mixed(): count_granularity=CG.WORKITEM) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f64uniform == (2*n*m)*n_workgroups*subgroups_per_group - assert f32uniform == (m*n)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f64uniform == (2*n*m)*n_subgroups + assert f32uniform == (m*n)*n_subgroups expect_fallback = False import islpy as isl @@ -651,8 +680,8 @@ def test_mem_access_counter_mixed(): count_granularity=CG.WORKITEM) ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f64uniform == m*n*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f64uniform == m*n*n_subgroups if expect_fallback: if ell < group_size_0: @@ -681,7 +710,7 @@ def test_mem_access_counter_nonconsec(): knl = lp.tag_inames(knl, {"i_inner": "l.0", "i_outer": "g.0"}) mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=32) # noqa + subgroup_size=SGS) # noqa n = 512 m = 256 ell = 128 @@ -939,30 +968,35 @@ def test_all_counters_parallel_matmul(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} + group_size = bsize*bsize + n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group sync_map = lp.get_synchronization_map(knl) assert len(sync_map) == 2 assert sync_map["kernel_launch"].eval_with_dict(params) == 1 assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) f32mul = op_map[ - lp.Op(np.float32, 'mul', CG.WORKITEM) + lp.Op(np.float32, 'mul', CG.SUBGROUP) ].eval_with_dict(params) f32add = op_map[ - lp.Op(np.float32, 'add', CG.WORKITEM) + lp.Op(np.float32, 'add', CG.SUBGROUP) ].eval_with_dict(params) i32ops = op_map[ - lp.Op(np.int32, 'add', CG.WORKITEM) + lp.Op(np.int32, 'add', CG.SUBGROUP) ].eval_with_dict(params) i32ops += op_map[ - lp.Op(np.dtype(np.int32), 'mul', CG.WORKITEM) + lp.Op(np.dtype(np.int32), 'mul', CG.SUBGROUP) ].eval_with_dict(params) - assert f32mul+f32add == n*m*ell*2 + # (count-per-sub-group)*n_subgroups + assert f32mul+f32add == m*2*n_subgroups mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=32) + subgroup_size=SGS) f32s1lb = mem_access_map[lp.MemAccess('global', np.float32, lid_strides={0: 1, 1: Variable('ell')}, @@ -991,33 +1025,36 @@ def test_all_counters_parallel_matmul(): local_mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=32).filter_by(mtype=['local']) + subgroup_size=SGS).filter_by(mtype=['local']) local_mem_l = local_mem_map.filter_by(direction=['load'] ).eval_and_sum(params) - assert local_mem_l == n*m*ell*2 + # (count-per-sub-group)*n_subgroups + assert local_mem_l == m*2*n_subgroups local_mem_l_a = local_mem_map[lp.MemAccess('local', np.dtype(np.float32), direction='load', lid_strides={1: 16}, gid_strides={}, variable='a_fetch', - count_granularity=CG.WORKITEM) + count_granularity=CG.SUBGROUP) ].eval_with_dict(params) local_mem_l_b = local_mem_map[lp.MemAccess('local', np.dtype(np.float32), direction='load', lid_strides={0: 1}, gid_strides={}, variable='b_fetch', - count_granularity=CG.WORKITEM) + count_granularity=CG.SUBGROUP) ].eval_with_dict(params) - assert local_mem_l_a == local_mem_l_b == n*m*ell + # (count-per-sub-group)*n_subgroups + assert local_mem_l_a == local_mem_l_b == m*n_subgroups local_mem_s = local_mem_map.filter_by(direction=['store'] ).eval_and_sum(params) - assert local_mem_s == n*m*ell*2/bsize + # (count-per-sub-group)*n_subgroups + assert local_mem_s == m*2/bsize*n_subgroups def test_gather_access_footprint(): @@ -1067,8 +1104,6 @@ def test_summations_and_filters(): knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) - subgroup_size = 32 - n = 512 m = 256 ell = 128 @@ -1076,24 +1111,25 @@ def test_summations_and_filters(): n_workgroups = 1 group_size = 1 - subgroups_per_group = div_ceil(group_size, subgroup_size) + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, - subgroup_size=subgroup_size) + subgroup_size=SGS) loads_a = mem_map.filter_by(direction=['load'], variable=['a'], count_granularity=[CG.SUBGROUP] ).eval_and_sum(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert loads_a == (2*n*m*ell)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert loads_a == (2*n*m*ell)*n_subgroups global_stores = mem_map.filter_by(mtype=['global'], direction=['store'], count_granularity=[CG.SUBGROUP] ).eval_and_sum(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert global_stores == (n*m*ell + n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert global_stores == (n*m*ell + n*m)*n_subgroups ld_bytes = mem_map.filter_by(mtype=['global'], direction=['load'], count_granularity=[CG.SUBGROUP] @@ -1102,9 +1138,9 @@ def test_summations_and_filters(): count_granularity=[CG.SUBGROUP] ).to_bytes().eval_and_sum(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_workgroups*subgroups_per_group - assert st_bytes == (4*n*m*ell + 8*n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert ld_bytes == (4*n*m*ell*3 + 8*n*m*2)*n_subgroups + assert st_bytes == (4*n*m*ell + 8*n*m)*n_subgroups # ignore stride and variable names in this map reduced_map = mem_map.group_by('mtype', 'dtype', 'direction') @@ -1113,11 +1149,11 @@ def test_summations_and_filters(): f64lall = reduced_map[lp.MemAccess('global', np.float64, direction='load') ].eval_with_dict(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f32lall == (3*n*m*ell)*n_workgroups*subgroups_per_group - assert f64lall == (2*n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f32lall == (3*n*m*ell)*n_subgroups + assert f64lall == (2*n*m)*n_subgroups - op_map = lp.get_op_map(knl, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) #for k, v in op_map.items(): # print(type(k), "\n", k.name, k.dtype, type(k.dtype), " :\n", v) @@ -1149,8 +1185,8 @@ def test_summations_and_filters(): key.direction == 'load' f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params) - # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert f64l == (2*n*m)*n_workgroups*subgroups_per_group + # uniform: (count-per-sub-group)*n_subgroups + assert f64l == (2*n*m)*n_subgroups def test_strided_footprint():