From 09420000fb7a565fd3fa64a6ca8e8a609eae8008 Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@porter.cs.illinois.edu>
Date: Sun, 13 Mar 2016 13:23:59 -0500
Subject: [PATCH] subscript counter now looks only for lid0; if not found,
 stride is set to maxsize

---
 loopy/statistics.py     | 35 ++++++++++++++++++++++++-----------
 test/test_statistics.py |  2 +-
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 5faeb12e3..e10de8cb4 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -291,15 +291,19 @@ class GlobalSubscriptCounter(CombineMapper):
         my_inames = get_dependencies(index) & self.knl.all_inames()
 
         # find min tag axis
-        import sys
-        min_tag_axis = sys.maxsize
+        #import sys
+        local_id0 = None
+        #min_tag_axis = sys.maxsize
         local_id_found = False
         for iname in my_inames:
             tag = self.knl.iname_to_tag.get(iname)
             if isinstance(tag, LocalIndexTag):
                 local_id_found = True
-                if tag.axis < min_tag_axis:
-                    min_tag_axis = tag.axis
+                #if tag.axis < min_tag_axis:
+                #    min_tag_axis = tag.axis
+                if tag.axis == 0:
+                    local_id0 = iname
+                    break
 
         if not local_id_found:
             # count as uniform access
@@ -307,6 +311,15 @@ class GlobalSubscriptCounter(CombineMapper):
                     {(self.type_inf(expr), DataAccess(stride=0)): 1}
                     ) + self.rec(expr.index)
 
+        if local_id0 is None:
+            # only non-zero local id(s) found, assume non-consecutive access
+            # TODO: sys.maxsize is a placeholder; decide how to count this case
+            import sys
+            return ToCountMap(
+                    {(self.type_inf(expr), DataAccess(stride=sys.maxsize)): 1}
+                    ) + self.rec(expr.index)
+
+        '''            
         # get local_id associated with minimum tag axis
         min_local_id = None
         for iname in my_inames:
@@ -315,11 +328,11 @@ class GlobalSubscriptCounter(CombineMapper):
                 if tag.axis == min_tag_axis:
                     min_local_id = iname
                     break  # there will be only one min local_id
+        '''
 
-        # found local_id associated with minimum tag axis
-
+        # found local_id associated with axis 0
         total_stride = None
-        # check coefficient of min_local_id for each axis
+        # check coefficient of local_id0 for each axis
         from loopy.symbolic import CoefficientCollector
         from pymbolic.primitives import Variable
         for idx, axis_tag in zip(index, array.dim_tags):
@@ -327,12 +340,12 @@ class GlobalSubscriptCounter(CombineMapper):
             coeffs = CoefficientCollector()(idx)
             # check if he contains the min lid guy
             try:
-                coeff_min_lid = coeffs[Variable(min_local_id)]
+                coeff_lid0 = coeffs[Variable(local_id0)]
             except KeyError:
-                # does not contain min_local_id
+                # does not contain local_id0
                 continue
 
-            # found coefficient of min_local_id
+            # found coefficient of local_id0
             # now determine stride
             from loopy.kernel.array import FixedStrideArrayDimTag
             if isinstance(axis_tag, FixedStrideArrayDimTag):
@@ -340,7 +353,7 @@ class GlobalSubscriptCounter(CombineMapper):
             else:
                 continue
 
-            total_stride = stride*coeff_min_lid
+            total_stride = stride*coeff_lid0
             #TODO is there a case where this^ does not execute, or executes more than once for two different axes?
 
         return ToCountMap({(self.type_inf(expr),
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 6e5b6270b..6aec20444 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -594,7 +594,7 @@ def test_all_counters_parallel_matmul():
 
     subscript_map = get_gmem_access_poly(knl)
     f32uncoal = subscript_map[
-                        (np.dtype(np.float32), DataAccess(stride=Variable('m')), 'load')
+                        (np.dtype(np.float32), DataAccess(stride=sys.maxsize), 'load')
                         ].eval_with_dict(params)
     f32coal = subscript_map[
                         (np.dtype(np.float32), DataAccess(stride=1), 'load')
-- 
GitLab