diff --git a/loopy/statistics.py b/loopy/statistics.py index e27a0f482885658888c97081e4fc1d97fcd149fd..d04b84e3cfda464825334b48b489290d5c025356 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -583,12 +583,12 @@ class MemAccess(Record): """ - def __init__(self, mtype=None, dtype=None, stride=None, direction=None, + def __init__(self, mtype=None, dtype=None, lid_strides=None, direction=None, variable=None, count_granularity=None): - #TODO currently giving all lmem access stride=None - if (mtype == 'local') and (stride is not None): - raise NotImplementedError("MemAccess: stride must be None when " + #TODO currently giving all lmem access lid_strides=None + if mtype == 'local' and lid_strides is not None: + raise NotImplementedError("MemAccess: lid_strides must be None when " "mtype is 'local'") #TODO currently giving all lmem access variable=None @@ -602,14 +602,14 @@ class MemAccess(Record): % (count_granularity, CountGranularity.ALL+[None])) if dtype is None: - Record.__init__(self, mtype=mtype, dtype=dtype, stride=stride, + Record.__init__(self, mtype=mtype, dtype=dtype, lid_strides=lid_strides, direction=direction, variable=variable, count_granularity=count_granularity) else: from loopy.types import to_loopy_type Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype), - stride=stride, direction=direction, variable=variable, - count_granularity=count_granularity) + lid_strides=lid_strides, direction=direction, + variable=variable, count_granularity=count_granularity) def __hash__(self): return hash(str(self)) @@ -619,7 +619,7 @@ class MemAccess(Record): return "MemAccess(%s, %s, %s, %s, %s, %s)" % ( self.mtype, self.dtype, - self.stride, + self.lid_strides, self.direction, self.variable, self.count_granularity) @@ -870,7 +870,7 @@ class GlobalMemAccessCounter(MemAccessCounter): return ToCountMap() return ToCountMap({MemAccess(mtype='global', - dtype=self.type_inf(expr), stride=0, + dtype=self.type_inf(expr), lid_strides=[], variable=name, count_granularity=CountGranularity.WORKITEM): 1} ) + self.rec(expr.index) @@ -896,86 +896,87 @@ class GlobalMemAccessCounter(MemAccessCounter): from loopy.kernel.data import LocalIndexTag my_inames = get_dependencies(index) & self.knl.all_inames() - # find min tag axis - import sys - min_tag_axis = sys.maxsize - local_id_found = False + # find all local index tags and corresponding inames + lid_to_iname = {} for iname in my_inames: tag = self.knl.iname_to_tag.get(iname) if isinstance(tag, LocalIndexTag): - local_id_found = True - if tag.axis < min_tag_axis: - min_tag_axis = tag.axis + lid_to_iname[tag.axis] = iname + + if not lid_to_iname: + + # no local id found, count as uniform access + # Note: + # lid_strides=[] when no local ids were found, + # lid_strides=[0, ...] if any local id is found and the lid0 stride is 0, + # either because no lid0 is found or because the stride of lid0 is 0 + warn_with_kernel(self.knl, "no_lid_found", + "GlobalSubscriptCounter: No local id found, " + "setting lid_strides to []. Expression: %s" + % (expr)) - if not local_id_found: - # count as uniform access return ToCountMap({MemAccess( mtype='global', - dtype=self.type_inf(expr), stride=0, + dtype=self.type_inf(expr), lid_strides=[], variable=name, count_granularity=CountGranularity.SUBGROUP): 1} ) + self.rec(expr.index) - if min_tag_axis != 0: - warn_with_kernel(self.knl, "unknown_gmem_stride", - "GlobalSubscriptCounter: Memory access minimum " - "tag axis %d != 0, stride unknown, using " - "sys.maxsize." % (min_tag_axis)) - return ToCountMap({MemAccess( - mtype='global', - dtype=self.type_inf(expr), - stride=sys.maxsize, variable=name, - count_granularity=CountGranularity.WORKITEM): 1} - ) + self.rec(expr.index) - - # get local_id associated with minimum tag axis - min_lid = None - for iname in my_inames: - tag = self.knl.iname_to_tag.get(iname) - if isinstance(tag, LocalIndexTag): - if tag.axis == min_tag_axis: - min_lid = iname - break # there will be only one min local_id - # found local_id associated with minimum tag axis + # create lid_strides dict (strides are coefficents in flattened index) + # i.e., we want {0:A, 1:B, 2:C, ...} where A, B, & C + # come from flattened index [... + C*lid2 + B*lid1 + A*lid0] - total_stride = 0 - # check coefficient of min_lid for each axis from loopy.symbolic import CoefficientCollector from loopy.kernel.array import FixedStrideArrayDimTag from pymbolic.primitives import Variable - for idx, axis_tag in zip(index, array.dim_tags): - from loopy.symbolic import simplify_using_aff - from loopy.diagnostic import ExpressionNotAffineError - try: - coeffs = CoefficientCollector()( - simplify_using_aff(self.knl, idx)) - except ExpressionNotAffineError: - total_stride = None - break - # check if he contains the lid 0 guy - try: - coeff_min_lid = coeffs[Variable(min_lid)] - except KeyError: - # does not contain min_lid - continue - # found coefficient of min_lid - # now determine stride - if isinstance(axis_tag, FixedStrideArrayDimTag): - stride = axis_tag.stride - else: - continue + lid_strides = {} + + for ltag, iname in six.iteritems(lid_to_iname): + ltag_stride = 0 + # check coefficient of this lid for each axis + for idx, axis_tag in zip(index, array.dim_tags): + + from loopy.symbolic import simplify_using_aff + from loopy.diagnostic import ExpressionNotAffineError + try: + coeffs = CoefficientCollector()( + simplify_using_aff(self.knl, idx)) + except ExpressionNotAffineError: + ltag_stride = None + break + + # check if idx contains this lid + try: + coeff_lid = coeffs[Variable(lid_to_iname[ltag])] + except KeyError: + # idx does not contain this lid + continue + + # found coefficient of this lid + # now determine stride + if isinstance(axis_tag, FixedStrideArrayDimTag): + stride = axis_tag.stride + else: + continue - total_stride += stride*coeff_min_lid + ltag_stride += stride*coeff_lid + lid_strides[ltag] = ltag_stride - count_granularity = CountGranularity.WORKITEM if total_stride is not 0 \ + # insert 0s for coeffs of missing *lesser* lids + for i in range(max(lid_strides.keys())+1): + if i not in lid_strides.keys(): + lid_strides[i] = 0 + + count_granularity = CountGranularity.WORKITEM if lid_strides[0] != 0 \ else CountGranularity.SUBGROUP return ToCountMap({MemAccess( mtype='global', dtype=self.type_inf(expr), - stride=total_stride, + lid_strides=[lid_strides[i] + for i in sorted(lid_strides)], variable=name, count_granularity=count_granularity ): 1} @@ -1532,14 +1533,12 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, + access_counter_l(insn.assignee) ).with_set_attributes(direction="store") - # use count excluding local index tags for uniform accesses for key, val in six.iteritems(access_expr.count_map): access_map = ( access_map + ToCountMap({key: val}) * get_insn_count(knl, insn.id, key.count_granularity)) - #currently not counting stride of local mem access for key, val in six.iteritems(access_assignee.count_map): @@ -1547,7 +1546,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, access_map + ToCountMap({key: val}) * get_insn_count(knl, insn.id, key.count_granularity)) - # for now, don't count writes to local mem + elif isinstance(insn, (NoOpInstruction, BarrierInstruction)): pass else: @@ -1560,7 +1559,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, (MemAccess( mtype=mem_access.mtype, dtype=mem_access.dtype.numpy_dtype, - stride=mem_access.stride, + lid_strides=mem_access.lid_strides, direction=mem_access.direction, variable=mem_access.variable, count_granularity=mem_access.count_granularity), diff --git a/test/test_statistics.py b/test/test_statistics.py index ea0bdb62bb75d8a5bcf7dd987c00c33b848091fd..732b9afe2ce7e8db01185be0e152fa3e43975eea 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -269,19 +269,19 @@ def test_mem_access_counter_basic(): subgroups_per_group = div_ceil(group_size, subgroup_size) f32l = mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='a', + lid_strides=[], direction='load', variable='a', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f32l += mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='b', + lid_strides=[], direction='load', variable='b', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64l = mem_map[lp.MemAccess('global', np.float64, - stride=0, direction='load', variable='g', + lid_strides=[], direction='load', variable='g', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64l += mem_map[lp.MemAccess('global', np.float64, - stride=0, direction='load', variable='h', + lid_strides=[], direction='load', variable='h', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -290,11 +290,11 @@ def test_mem_access_counter_basic(): assert f64l == (2*n*m)*n_workgroups*subgroups_per_group f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=0, direction='store', variable='c', + lid_strides=[], direction='store', variable='c', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64s = mem_map[lp.MemAccess('global', np.dtype(np.float64), - stride=0, direction='store', variable='e', + lid_strides=[], direction='store', variable='e', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -328,11 +328,11 @@ def test_mem_access_counter_reduction(): subgroups_per_group = div_ceil(group_size, subgroup_size) f32l = mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='a', + lid_strides=[], direction='load', variable='a', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f32l += mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='b', + lid_strides=[], direction='load', variable='b', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -340,7 +340,7 @@ def test_mem_access_counter_reduction(): assert f32l == (2*n*m*ell)*n_workgroups*subgroups_per_group f32s = mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=0, direction='store', variable='c', + lid_strides=[], direction='store', variable='c', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -430,19 +430,19 @@ def test_mem_access_counter_specialops(): subgroups_per_group = div_ceil(group_size, subgroup_size) f32 = mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='a', + lid_strides=[], direction='load', variable='a', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f32 += mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='b', + lid_strides=[], direction='load', variable='b', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64 = mem_map[lp.MemAccess('global', np.dtype(np.float64), - stride=0, direction='load', variable='g', + lid_strides=[], direction='load', variable='g', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64 += mem_map[lp.MemAccess('global', np.dtype(np.float64), - stride=0, direction='load', variable='h', + lid_strides=[], direction='load', variable='h', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -451,11 +451,11 @@ def test_mem_access_counter_specialops(): assert f64 == (2*n*m)*n_workgroups*subgroups_per_group f32 = mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='store', variable='c', + lid_strides=[], direction='store', variable='c', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64 = mem_map[lp.MemAccess('global', np.float64, - stride=0, direction='store', variable='e', + lid_strides=[], direction='store', variable='e', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -502,19 +502,19 @@ def test_mem_access_counter_bitwise(): subgroups_per_group = div_ceil(group_size, subgroup_size) i32 = mem_map[lp.MemAccess('global', np.int32, - stride=0, direction='load', variable='a', + lid_strides=[], direction='load', variable='a', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) i32 += mem_map[lp.MemAccess('global', np.int32, - stride=0, direction='load', variable='b', + lid_strides=[], direction='load', variable='b', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) i32 += mem_map[lp.MemAccess('global', np.int32, - stride=0, direction='load', variable='g', + lid_strides=[], direction='load', variable='g', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) i32 += mem_map[lp.MemAccess('global', np.dtype(np.int32), - stride=0, direction='load', variable='h', + lid_strides=[], direction='load', variable='h', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -522,11 +522,11 @@ def test_mem_access_counter_bitwise(): assert i32 == (4*n*m+2*n*m*ell)*n_workgroups*subgroups_per_group i32 = mem_map[lp.MemAccess('global', np.int32, - stride=0, direction='store', variable='c', + lid_strides=[], direction='store', variable='c', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) i32 += mem_map[lp.MemAccess('global', np.int32, - stride=0, direction='store', variable='e', + lid_strides=[], direction='store', variable='e', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) @@ -567,24 +567,24 @@ def test_mem_access_counter_mixed(): mem_map = lp.get_mem_access_map(knl, count_redundant_work=True, subgroup_size=subgroup_size) f64uniform = mem_map[lp.MemAccess('global', np.float64, - stride=0, direction='load', variable='g', + lid_strides=[], direction='load', variable='g', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f64uniform += mem_map[lp.MemAccess('global', np.float64, - stride=0, direction='load', variable='h', + lid_strides=[], direction='load', variable='h', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f32uniform = mem_map[lp.MemAccess('global', np.float32, - stride=0, direction='load', variable='x', + lid_strides=[], direction='load', variable='x', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f32nonconsec = mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=Variable('m'), direction='load', + lid_strides=[Variable('m')], direction='load', variable='a', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32nonconsec += mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=Variable('m'), direction='load', + lid_strides=[Variable('m')], direction='load', variable='b', count_granularity=CG.WORKITEM) ].eval_with_dict(params) @@ -611,11 +611,11 @@ def test_mem_access_counter_mixed(): assert f32nonconsec == 3*n*m*ell f64uniform = mem_map[lp.MemAccess('global', np.float64, - stride=0, direction='store', variable='e', + lid_strides=[], direction='store', variable='e', count_granularity=CG.SUBGROUP) ].eval_with_dict(params) f32nonconsec = mem_map[lp.MemAccess('global', np.float32, - stride=Variable('m'), direction='store', + lid_strides=[Variable('m')], direction='store', variable='c', count_granularity=CG.WORKITEM) ].eval_with_dict(params) @@ -655,22 +655,22 @@ def test_mem_access_counter_nonconsec(): ell = 128 params = {'n': n, 'm': m, 'ell': ell} f64nonconsec = mem_map[lp.MemAccess('global', np.float64, - stride=Variable('m'), direction='load', + lid_strides=[Variable('m')], direction='load', variable='g', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f64nonconsec += mem_map[lp.MemAccess('global', np.float64, - stride=Variable('m'), direction='load', + lid_strides=[Variable('m')], direction='load', variable='h', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32nonconsec = mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=Variable('m')*Variable('ell'), + lid_strides=[Variable('m')*Variable('ell')], direction='load', variable='a', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32nonconsec += mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=Variable('m')*Variable('ell'), + lid_strides=[Variable('m')*Variable('ell')], direction='load', variable='b', count_granularity=CG.WORKITEM) ].eval_with_dict(params) @@ -678,12 +678,12 @@ def test_mem_access_counter_nonconsec(): assert f32nonconsec == 3*n*m*ell f64nonconsec = mem_map[lp.MemAccess('global', np.float64, - stride=Variable('m'), direction='store', + lid_strides=[Variable('m')], direction='store', variable='e', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32nonconsec = mem_map[lp.MemAccess('global', np.float32, - stride=Variable('m')*Variable('ell'), + lid_strides=[Variable('m')*Variable('ell')], direction='store', variable='c', count_granularity=CG.WORKITEM) ].eval_with_dict(params) @@ -694,20 +694,20 @@ def test_mem_access_counter_nonconsec(): subgroup_size=64) f64nonconsec = mem_map64[lp.MemAccess( 'global', - np.float64, stride=Variable('m'), + np.float64, lid_strides=[Variable('m')], direction='load', variable='g', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f64nonconsec += mem_map64[lp.MemAccess( 'global', - np.float64, stride=Variable('m'), + np.float64, lid_strides=[Variable('m')], direction='load', variable='h', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32nonconsec = mem_map64[lp.MemAccess( 'global', np.dtype(np.float32), - stride=Variable('m')*Variable('ell'), + lid_strides=[Variable('m')*Variable('ell')], direction='load', variable='a', count_granularity=CG.WORKITEM) @@ -715,7 +715,7 @@ def test_mem_access_counter_nonconsec(): f32nonconsec += mem_map64[lp.MemAccess( 'global', np.dtype(np.float32), - stride=Variable('m')*Variable('ell'), + lid_strides=[Variable('m')*Variable('ell')], direction='load', variable='b', count_granularity=CG.WORKITEM) @@ -747,30 +747,30 @@ def test_mem_access_counter_consec(): params = {'n': n, 'm': m, 'ell': ell} f64consec = mem_map[lp.MemAccess('global', np.float64, - stride=1, direction='load', variable='g', + lid_strides=[1], direction='load', variable='g', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f64consec += mem_map[lp.MemAccess('global', np.float64, - stride=1, direction='load', variable='h', + lid_strides=[1], direction='load', variable='h', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32consec = mem_map[lp.MemAccess('global', np.float32, - stride=1, direction='load', variable='a', + lid_strides=[1], direction='load', variable='a', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32consec += mem_map[lp.MemAccess('global', np.dtype(np.float32), - stride=1, direction='load', variable='b', + lid_strides=[1], direction='load', variable='b', count_granularity=CG.WORKITEM) ].eval_with_dict(params) assert f64consec == 2*n*m*ell assert f32consec == 3*n*m*ell f64consec = mem_map[lp.MemAccess('global', np.float64, - stride=1, direction='store', variable='e', + lid_strides=[1], direction='store', variable='e', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32consec = mem_map[lp.MemAccess('global', np.float32, - stride=1, direction='store', variable='c', + lid_strides=[1], direction='store', variable='c', count_granularity=CG.WORKITEM) ].eval_with_dict(params) assert f64consec == n*m*ell @@ -853,7 +853,6 @@ def test_barrier_counter_barriers(): def test_all_counters_parallel_matmul(): - bsize = 16 knl = lp.make_kernel( "{[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<ell}", @@ -898,11 +897,11 @@ def test_all_counters_parallel_matmul(): subgroup_size=32) f32s1lb = mem_access_map[lp.MemAccess('global', np.float32, - stride=1, direction='load', variable='b', + lid_strides=[1, Variable('ell')], direction='load', variable='b', count_granularity=CG.WORKITEM) ].eval_with_dict(params) f32s1la = mem_access_map[lp.MemAccess('global', np.float32, - stride=1, direction='load', variable='a', + lid_strides=[1, Variable('m')], direction='load', variable='a', count_granularity=CG.WORKITEM) ].eval_with_dict(params) @@ -910,7 +909,7 @@ def test_all_counters_parallel_matmul(): assert f32s1la == n*m*ell/bsize f32coal = mem_access_map[lp.MemAccess('global', np.float32, - stride=1, direction='store', variable='c', + lid_strides=[1, Variable('ell')], direction='store', variable='c', count_granularity=CG.WORKITEM) ].eval_with_dict(params) @@ -1057,12 +1056,12 @@ def test_summations_and_filters(): assert f64ops_all == n*m def func_filter(key): - return key.stride < 1 and key.dtype == to_loopy_type(np.float64) and \ + return key.lid_strides == [] and key.dtype == to_loopy_type(np.float64) and \ key.direction == 'load' - s1f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params) + f64l = mem_map.filter_by_func(func_filter).eval_and_sum(params) # uniform: (count-per-sub-group)*n_workgroups*subgroups_per_group - assert s1f64l == (2*n*m)*n_workgroups*subgroups_per_group + assert f64l == (2*n*m)*n_workgroups*subgroups_per_group def test_strided_footprint():