diff --git a/loopy/statistics.py b/loopy/statistics.py index 91d15e7e079d250ec879fb84be45e7bb40fe4320..07f29d8ab5c39da02b3267464f2a0fa462b570ba 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -269,7 +269,7 @@ class ToCountMap(object): params = {'n': 512, 'm': 256, 'l': 128} mem_map = lp.get_mem_access_map(knl) def filter_func(key): - return key.stride > 1 and key.stride <= 4: + return key.lid_strides[0] > 1 and key.lid_strides[0] <= 4 filtered_map = mem_map.filter_by_func(filter_func) tot = filtered_map.eval_and_sum(params) @@ -374,16 +374,16 @@ class ToCountMap(object): params = {'n': 512, 'm': 256, 'l': 128} s1_g_ld_byt = bytes_map.filter_by( - mtype=['global'], stride=[1], + mtype=['global'], lid_strides={0: 1}, direction=['load']).eval_and_sum(params) s2_g_ld_byt = bytes_map.filter_by( - mtype=['global'], stride=[2], + mtype=['global'], lid_strides={0: 2}, direction=['load']).eval_and_sum(params) s1_g_st_byt = bytes_map.filter_by( - mtype=['global'], stride=[1], + mtype=['global'], lid_strides={0: 1}, direction=['store']).eval_and_sum(params) s2_g_st_byt = bytes_map.filter_by( - mtype=['global'], stride=[2], + mtype=['global'], lid_strides={0: 2}, direction=['store']).eval_and_sum(params) # (now use these counts to, e.g., predict performance) @@ -440,7 +440,7 @@ class ToCountMap(object): params = {'n': 512, 'm': 256, 'l': 128} mem_map = lp.get_mem_access_map(knl) filtered_map = mem_map.filter_by(direction=['load'], - variable=['a','g']) + variable=['a', 'g']) tot_loads_a_g = filtered_map.eval_and_sum(params) # (now use these counts to, e.g., predict performance) @@ -553,10 +553,16 @@ class MemAccess(Record): A :class:`loopy.LoopyType` or :class:`numpy.dtype` that specifies the data type accessed. - .. attribute:: stride + .. attribute:: lid_strides - An :class:`int` that specifies stride of the memory access. A stride of - 0 indicates a uniform access (i.e. all work-items access the same item). 
+ A :class:`dict` of **{** :class:`int` **:** + :class:`pymbolic.primitives.Variable` or :class:`int` **}** that + specifies local strides for each local id in the memory access index. + Local ids not found will not be present in ``lid_strides.keys()``. + Uniform access (i.e. work-items within a sub-group access the same + item) is indicated by setting ``lid_strides[0]=0``, but may also occur + when no local id 0 is found, in which case the 0 key will not be + present in ``lid_strides``. .. attribute:: direction @@ -966,11 +972,6 @@ class GlobalMemAccessCounter(MemAccessCounter): ltag_stride += stride*coeff_lid lid_strides[ltag] = ltag_stride - # insert 0s for coeffs of missing *lesser* lids - #for i in range(max(lid_strides.keys())+1): - # if i not in lid_strides.keys(): - # lid_strides[i] = 0 - count_granularity = CountGranularity.WORKITEM if ( 0 in lid_strides and lid_strides[0] != 0 ) else CountGranularity.SUBGROUP @@ -1389,7 +1390,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, f32_s1_g_ld_a = mem_map[MemAccess( mtype='global', dtype=np.float32, - stride=1, + lid_strides={0: 1}, direction='load', variable='a', count_granularity=CountGranularity.WORKITEM) @@ -1397,7 +1398,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, f32_s1_g_st_a = mem_map[MemAccess( mtype='global', dtype=np.float32, - stride=1, + lid_strides={0: 1}, direction='store', variable='a', count_granularity=CountGranularity.WORKITEM) @@ -1405,7 +1406,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, f32_s1_l_ld_x = mem_map[MemAccess( mtype='local', dtype=np.float32, - stride=1, + lid_strides={0: 1}, direction='load', variable='x', count_granularity=CountGranularity.WORKITEM) @@ -1413,7 +1414,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, f32_s1_l_st_x = mem_map[MemAccess( mtype='local', dtype=np.float32, - stride=1, + lid_strides={0: 1}, direction='store', variable='x', 
count_granularity=CountGranularity.WORKITEM)