From 2dbeb3877549f4564d96aae3314ab6636d8c8a56 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@porter.cs.illinois.edu>
Date: Tue, 15 Mar 2016 18:55:39 -0500
Subject: [PATCH] now calculating strides greater than 1

---
 loopy/statistics.py     | 56 +++++++++++++++++++++--------------------
 test/test_statistics.py |  2 +-
 2 files changed, 30 insertions(+), 28 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index e10de8cb4..eff571668 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -291,19 +291,19 @@ class GlobalSubscriptCounter(CombineMapper):
         my_inames = get_dependencies(index) & self.knl.all_inames()
 
         # find min tag axis
-        #import sys
-        local_id0 = None
-        #min_tag_axis = sys.maxsize
+        import sys
+        #local_id0 = None
+        min_tag_axis = sys.maxsize
         local_id_found = False
         for iname in my_inames:
             tag = self.knl.iname_to_tag.get(iname)
             if isinstance(tag, LocalIndexTag):
                 local_id_found = True
-                #if tag.axis < min_tag_axis:
-                #    min_tag_axis = tag.axis
-                if tag.axis == 0:
-                    local_id0 = iname
-                    break
+                if tag.axis < min_tag_axis:
+                    min_tag_axis = tag.axis
+                #if tag.axis == 0:
+                #    local_id0 = iname
+                #    break
 
         if not local_id_found:
             # count as uniform access
@@ -311,49 +311,51 @@ class GlobalSubscriptCounter(CombineMapper):
                     {(self.type_inf(expr), DataAccess(stride=0)): 1}
                     ) + self.rec(expr.index)
 
-        if local_id0 is None:
-            # only non-zero local id(s) found, assume non-consecutive access
-            #TODO what to do here?
-            import sys
-            return ToCountMap(
-                    {(self.type_inf(expr), DataAccess(stride=sys.maxsize)): 1}
-                    ) + self.rec(expr.index)
-
-        '''            
         # get local_id associated with minimum tag axis
-        min_local_id = None
+        min_lid = None
         for iname in my_inames:
             tag = self.knl.iname_to_tag.get(iname)
             if isinstance(tag, LocalIndexTag):
                 if tag.axis == min_tag_axis:
-                    min_local_id = iname
+                    min_lid = iname
                     break  # there will be only one min local_id
-        '''
 
-        # found local_id associated with axis 0
+        # found local_id associated with minimum tag axis
+
         total_stride = None
         # check coefficient of local_id0 for each axis
         from loopy.symbolic import CoefficientCollector
         from pymbolic.primitives import Variable
+        #print("==========================================================================================")
+        #print("expr: ", expr)
+        #print("min_lid: ", min_lid)
+        #print("min_tag_axis: ", min_tag_axis)
+        #print("Var(min_lid): ", Variable(min_lid))
         for idx, axis_tag in zip(index, array.dim_tags):
-
+            #print("...........................................................................................")
+            #print("idx, axis_tag: ", idx, "\t",  axis_tag)
             coeffs = CoefficientCollector()(idx)
+            #print("coeffs: ", coeffs)
             # check if he contains the min lid guy
             try:
-                coeff_lid0 = coeffs[Variable(local_id0)]
+                coeff_min_lid = coeffs[Variable(min_lid)]
             except KeyError:
-                # does not contain local_id0
+                # does not contain min_lid
+                #print("key error")
                 continue
-
-            # found coefficient of local_id0
+            #print("coeff_min_lid: ", coeff_min_lid)
+            #print("axis_tag: ", axis_tag)
+            # found coefficient of min_lid
             # now determine stride
             from loopy.kernel.array import FixedStrideArrayDimTag
             if isinstance(axis_tag, FixedStrideArrayDimTag):
                 stride = axis_tag.stride
             else:
+                #print("continuing")
                 continue
+            #print("stride: ", stride)
 
-            total_stride = stride*coeff_lid0
+            total_stride = stride*coeff_min_lid
             #TODO is there a case where this^ does not execute, or executes more than once for two different axes?
 
         return ToCountMap({(self.type_inf(expr),
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 6aec20444..6e5b6270b 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -594,7 +594,7 @@ def test_all_counters_parallel_matmul():
 
     subscript_map = get_gmem_access_poly(knl)
     f32uncoal = subscript_map[
-                        (np.dtype(np.float32), DataAccess(stride=sys.maxsize), 'load')
+                        (np.dtype(np.float32), DataAccess(stride=Variable('m')), 'load')
                         ].eval_with_dict(params)
     f32coal = subscript_map[
                         (np.dtype(np.float32), DataAccess(stride=1), 'load')
-- 
GitLab