diff --git a/doc/tutorial.rst b/doc/tutorial.rst index 3c85060dacf03b52f6e0b1faf05ad4697b6a5d07..b4d5fe5a9be28f509c8d52ae4ef54a9c2f662b09 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1827,7 +1827,7 @@ criteria are more complicated than a simple list of allowable values: .. doctest:: - >>> def f(key): + >>> def f(key, val): ... from loopy.types import to_loopy_type ... return key.dtype == to_loopy_type(np.float32) and \ ... key.lid_strides[0] > 1 diff --git a/loopy/statistics.py b/loopy/statistics.py index 10d29daad062744ca3fbe2dc2261be4cd2c4ca99..5e7a97f3262762244be671f9e92cf6de527429b3 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -257,11 +257,11 @@ class ToCountMap(object): def filter_by_func(self, func): """Keep items that pass a test. - :arg func: A function that takes a map key a parameter and returns a - :class:`bool`. + :arg func: A function that takes a map key and val as parameters and + returns a :class:`bool`. - :arg: A :class:`ToCountMap` containing the subset of the items in the - original :class:`ToCountMap` for which func(key) is true. + :return: A :class:`ToCountMap` containing the subset of the items in the + original :class:`ToCountMap` for which func(key, val) is true. Example usage:: @@ -269,7 +269,7 @@ class ToCountMap(object): params = {'n': 512, 'm': 256, 'l': 128} mem_map = lp.get_mem_access_map(knl) - def filter_func(key): + def filter_func(key, val): return key.lid_strides[0] > 1 and key.lid_strides[0] <= 4: filtered_map = mem_map.filter_by_func(filter_func) @@ -283,7 +283,7 @@ class ToCountMap(object): # for each item in self.count_map, call func on the key for self_key, self_val in self.items(): - if func(self_key): + if func(self_key, self_val): result_map[self_key] = self_val return result_map @@ -499,7 +499,7 @@ class Op(Record): .. attribute:: name A :class:`str` that specifies the kind of arithmetic operation as - *add*, *mul*, *div*, *pow*, *shift*, *bw* (bitwise), etc. + *add*, *mul*, *madd*, *div*, *pow*, *shift*, *bw* (bitwise), etc. .. attribute:: count_granularity @@ -707,8 +707,9 @@ class CounterBase(CombineMapper): # {{{ ExpressionOpCounter class ExpressionOpCounter(CounterBase): - def __init__(self, knl, count_within_subscripts=True): + def __init__(self, knl, count_madds, count_within_subscripts=True): self.knl = knl + self.count_madds = count_madds self.count_within_subscripts = count_within_subscripts from loopy.type_inference import TypeInferenceMapper self.type_inf = TypeInferenceMapper(knl) @@ -736,13 +737,72 @@ class ExpressionOpCounter(CounterBase): return ToCountMap() def map_sum(self, expr): - assert expr.children - return ToCountMap( - {Op(dtype=self.type_inf(expr), - name='add', - count_granularity=CountGranularity.SUBGROUP): - len(expr.children)-1} - ) + sum(self.rec(child) for child in expr.children) + if not self.count_madds: + return ToCountMap( + {Op(dtype=self.type_inf(expr), + name='add', + count_granularity=CountGranularity.SUBGROUP): + len(expr.children) - 1} + ) + sum(self.rec(child) for child in expr.children) + + # construct ToCountMap + result = ToCountMap() + + # first compute count for madds, adds, and muls in expr and expr.children, + # and recurse on any uncounted expressions + adds_available = len(expr.children)-1 + madd_ct = 0 + child_mul_ct = 0 + sum_expr_type = self.type_inf(expr) + from pymbolic.primitives import Product, is_zero + + for child in expr.children: + # if child is Product w/ matching dtype and unpaired adds exist, + # then consider it as potential source for madds, + # otherwise recurse on child as usual + if isinstance(child, Product) and self.type_inf(child) == sum_expr_type \ + and adds_available > madd_ct: + # process this product as in map_product(), + # but first check to see if one mul can be counted as a madd + + # count muls excluding negation + # i.e., (-1)*x contains 0 muls; (-1)*x*y contains one mul + child_muls_available = len(child.children) - 1 - sum( + is_zero(grandchild + 1) for grandchild in child.children) + + if child_muls_available: + madd_ct += 1 # count one mul as madd + # if there are remaining muls, count them + if child_muls_available > 1: + child_mul_ct += child_muls_available - 1 + + # recurse on grandchildren that are not (-1), as in map_product + result += sum(self.rec(grandchild) for grandchild in child.children + if not is_zero(child + 1)) + else: # not madd + result += self.rec(child) # recurse as usual + + # second, insert op for madds, adds, and muls if count is non-zero + if adds_available > madd_ct: + result += ToCountMap( + {Op(dtype=sum_expr_type, + name='add', + count_granularity=CountGranularity.SUBGROUP): + adds_available - madd_ct}) + if madd_ct: + result += ToCountMap( + {Op(dtype=sum_expr_type, + name='madd', + count_granularity=CountGranularity.SUBGROUP): + madd_ct}) + if child_mul_ct: + result += ToCountMap( + {Op(dtype=sum_expr_type, + name='mul', + count_granularity=CountGranularity.SUBGROUP): + child_mul_ct}) + + return result def map_product(self, expr): from pymbolic.primitives import is_zero @@ -1345,7 +1405,8 @@ def _get_insn_count(knl, insn_id, subgroup_size, count_redundant_work, # {{{ get_op_map def get_op_map(knl, numpy_types=True, count_redundant_work=False, - count_within_subscripts=True, subgroup_size=None): + count_within_subscripts=True, subgroup_size=None, + count_madds=False): """Count the number of operations in a loopy kernel. @@ -1376,6 +1437,9 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False, attempt to find the sub-group size using the device and, if unsuccessful, will make a wild guess. + :arg count_madds: A :class:`bool` determining whether to count + multiplication followed by addition as a single operation. + :return: A :class:`ToCountMap` of **{** :class:`Op` **:** :class:`islpy.PwQPolynomial` **}**. @@ -1416,7 +1480,7 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False, knl = preprocess_kernel(knl) op_map = ToCountMap() - op_counter = ExpressionOpCounter(knl, count_within_subscripts) + op_counter = ExpressionOpCounter(knl, count_madds, count_within_subscripts) from loopy.kernel.instruction import ( CallInstruction, CInstruction, Assignment, @@ -1426,12 +1490,13 @@ def get_op_map(knl, numpy_types=True, count_redundant_work=False, if isinstance(insn, (CallInstruction, CInstruction, Assignment)): ops = op_counter(insn.assignee) + op_counter(insn.expression) for key, val in six.iteritems(ops.count_map): - op_map = ( - op_map - + ToCountMap({key: val}) - * _get_insn_count(knl, insn.id, subgroup_size, - count_redundant_work, - key.count_granularity)) + if val != 0: + op_map = ( + op_map + + ToCountMap({key: val}) + * _get_insn_count(knl, insn.id, subgroup_size, + count_redundant_work, + key.count_granularity)) elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass diff --git a/test/test_statistics.py b/test/test_statistics.py index 41a88b3864166b81d60ec0468cf9e5fbd07c227c..1c4ef1fd3e3b44440accd74e8405ce84fa7219bd 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -90,7 +90,8 @@ def test_op_counter_reduction(): name="matmul_serial", assumptions="n,m,ell >= 1") knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32)) - op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True) + op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True, + count_madds=True) n_workgroups = 1 group_size = 1 subgroups_per_group = div_ceil(group_size, SGS) @@ -99,15 +100,13 @@ def test_op_counter_reduction(): m = 256 ell = 128 params = {'n': n, 'm': m, 'ell': ell} - f32add = op_map[lp.Op(np.float32, 'add', CG.SUBGROUP)].eval_with_dict(params) - f32mul = op_map[lp.Op(np.dtype(np.float32), 'mul', CG.SUBGROUP) - ].eval_with_dict(params) + f32madd = op_map[lp.Op(np.float32, 'madd', CG.SUBGROUP)].eval_with_dict(params) # (count-per-sub-group)*n_subgroups - assert f32add == f32mul == n*m*ell*n_subgroups + assert f32madd == n*m*ell*n_subgroups op_map_dtype = op_map.group_by('dtype') f32 = op_map_dtype[lp.Op(dtype=np.float32)].eval_with_dict(params) - assert f32 == f32add + f32mul + assert f32 == f32madd def test_op_counter_logic(): @@ -163,7 +162,7 @@ def test_op_counter_specialops(): dict(a=np.float32, b=np.float32, g=np.float64, h=np.float64)) op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True, - count_within_subscripts=True) + count_within_subscripts=True, count_madds=True) n_workgroups = 1 group_size = 1 subgroups_per_group = div_ceil(group_size, SGS) @@ -178,6 +177,8 @@ def test_op_counter_specialops(): f64pow = op_map[lp.Op(np.float64, 'pow', CG.SUBGROUP)].eval_with_dict(params) f64add = op_map[lp.Op(np.dtype(np.float64), 'add', CG.SUBGROUP) ].eval_with_dict(params) + f64madd = op_map[lp.Op(np.dtype(np.float64), 'madd', CG.SUBGROUP) + ].eval_with_dict(params) i32add = op_map[lp.Op(np.dtype(np.int32), 'add', CG.SUBGROUP) ].eval_with_dict(params) f64rsq = op_map[lp.Op(np.dtype(np.float64), 'func:rsqrt', CG.SUBGROUP) @@ -187,7 +188,8 @@ def test_op_counter_specialops(): # (count-per-sub-group)*n_subgroups assert f32div == 2*n*m*ell*n_subgroups assert f32mul == f32add == n*m*ell*n_subgroups - assert f64add == 3*n*m*n_subgroups + assert f64add == 2*n*m*n_subgroups + assert f64madd == n*m*n_subgroups assert f64pow == i32add == f64rsq == f64sin == n*m*n_subgroups @@ -209,7 +211,7 @@ def test_op_counter_bitwise(): g=np.int64, h=np.int64)) op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True, - count_within_subscripts=False) + count_within_subscripts=False, count_madds=False) n_workgroups = 1 group_size = 1 subgroups_per_group = div_ceil(group_size, SGS) @@ -236,6 +238,272 @@ def test_op_counter_bitwise(): assert i64shift == 2*n*m*n_subgroups +def test_op_counter_madd(): + n_workgroups = 1 + group_size = 1 + subgroups_per_group = div_ceil(group_size, SGS) + n_subgroups = n_workgroups*subgroups_per_group + n = 512 + m = 256 + ell = 128 + params = {'n': n, 'm': m, 'ell': ell} + + # ------------------------------------------------------------------------------- + # standard madd counting + + # perform madd + knl = lp.make_kernel( + "[n,m] -> {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j,k]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i {[i,j]: 0<=i