From 19525df93ff6e82093548f0e1f3aa60b7137a8cb Mon Sep 17 00:00:00 2001
From: jdsteve2 <jdsteve2@illinois.edu>
Date: Sun, 4 Mar 2018 02:37:15 -0600
Subject: [PATCH] updated docstrings

---
 loopy/statistics.py | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 91d15e7e0..07f29d8ab 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -269,7 +269,7 @@ class ToCountMap(object):
             params = {'n': 512, 'm': 256, 'l': 128}
             mem_map = lp.get_mem_access_map(knl)
             def filter_func(key):
-                return key.stride > 1 and key.stride <= 4:
+                return key.lid_strides[0] > 1 and key.lid_strides[0] <= 4:
 
             filtered_map = mem_map.filter_by_func(filter_func)
             tot = filtered_map.eval_and_sum(params)
@@ -374,16 +374,16 @@ class ToCountMap(object):
             params = {'n': 512, 'm': 256, 'l': 128}
 
             s1_g_ld_byt = bytes_map.filter_by(
-                                mtype=['global'], stride=[1],
+                                mtype=['global'], lid_strides={0: 1},
                                 direction=['load']).eval_and_sum(params)
             s2_g_ld_byt = bytes_map.filter_by(
-                                mtype=['global'], stride=[2],
+                                mtype=['global'], lid_strides={0: 2},
                                 direction=['load']).eval_and_sum(params)
             s1_g_st_byt = bytes_map.filter_by(
-                                mtype=['global'], stride=[1],
+                                mtype=['global'], lid_strides={0: 1},
                                 direction=['store']).eval_and_sum(params)
             s2_g_st_byt = bytes_map.filter_by(
-                                mtype=['global'], stride=[2],
+                                mtype=['global'], lid_strides={0: 2},
                                 direction=['store']).eval_and_sum(params)
 
             # (now use these counts to, e.g., predict performance)
@@ -440,7 +440,7 @@ class ToCountMap(object):
             params = {'n': 512, 'm': 256, 'l': 128}
             mem_map = lp.get_mem_access_map(knl)
             filtered_map = mem_map.filter_by(direction=['load'],
-                                             variable=['a','g'])
+                                             variable=['a', 'g'])
             tot_loads_a_g = filtered_map.eval_and_sum(params)
 
             # (now use these counts to, e.g., predict performance)
@@ -553,10 +553,16 @@ class MemAccess(Record):
        A :class:`loopy.LoopyType` or :class:`numpy.dtype` that specifies the
        data type accessed.
 
-    .. attribute:: stride
+    .. attribute:: lid_strides
 
-       An :class:`int` that specifies stride of the memory access. A stride of
-       0 indicates a uniform access (i.e. all work-items access the same item).
+       An :class:`dict` of **{** :class:`int` **:**
+       :class:`pymbolic.primitives.Variable` or :class:`int` **}** that
+       specifies local strides for each local id in the memory access index.
+       Local ids not found will not be present in ``lid_strides.keys()``.
+       Uniform access (i.e. work-items within a sub-group access the same
+       item) is indicated by setting ``lid_strides[0]=0``, but may also occur
+       when no local id 0 is found, in which case the 0 key will not be
+       present in lid_strides.
 
     .. attribute:: direction
 
@@ -966,11 +972,6 @@ class GlobalMemAccessCounter(MemAccessCounter):
                 ltag_stride += stride*coeff_lid
             lid_strides[ltag] = ltag_stride
 
-        # insert 0s for coeffs of missing *lesser* lids
-        #for i in range(max(lid_strides.keys())+1):
-        #    if i not in lid_strides.keys():
-        #        lid_strides[i] = 0
-
         count_granularity = CountGranularity.WORKITEM if (
                                 0 in lid_strides and lid_strides[0] != 0
                                 ) else CountGranularity.SUBGROUP
@@ -1389,7 +1390,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
         f32_s1_g_ld_a = mem_map[MemAccess(
                                     mtype='global',
                                     dtype=np.float32,
-                                    stride=1,
+                                    lid_strides={0: 1},
                                     direction='load',
                                     variable='a',
                                     count_granularity=CountGranularity.WORKITEM)
@@ -1397,7 +1398,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
         f32_s1_g_st_a = mem_map[MemAccess(
                                     mtype='global',
                                     dtype=np.float32,
-                                    stride=1,
+                                    lid_strides={0: 1},
                                     direction='store',
                                     variable='a',
                                     count_granularity=CountGranularity.WORKITEM)
@@ -1405,7 +1406,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
         f32_s1_l_ld_x = mem_map[MemAccess(
                                     mtype='local',
                                     dtype=np.float32,
-                                    stride=1,
+                                    lid_strides={0: 1},
                                     direction='load',
                                     variable='x',
                                     count_granularity=CountGranularity.WORKITEM)
@@ -1413,7 +1414,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
         f32_s1_l_st_x = mem_map[MemAccess(
                                     mtype='local',
                                     dtype=np.float32,
-                                    stride=1,
+                                    lid_strides={0: 1},
                                     direction='store',
                                     variable='x',
                                     count_granularity=CountGranularity.WORKITEM)
-- 
GitLab