From 7f3442f59dc956d59f3efc44d66737ff7e848a6c Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@porter.cs.illinois.edu>
Date: Sat, 30 Jan 2016 23:51:38 -0600
Subject: [PATCH] uniform loads and stores now only counted once per thread
 group

---
 loopy/statistics.py     | 35 ++++++++++++++++++++++++++++++-----
 test/test_statistics.py | 12 +++++++++---
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 992c95a4e..e8ff22412 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -606,20 +606,45 @@ def get_gmem_access_poly(knl):  # for now just counting subscripts
     subs_poly = ToCountMap()
     subscript_counter = GlobalSubscriptCounter(knl)
     for insn in knl.instructions:
-        insn_inames = knl.insn_inames(insn)
-        inames_domain = knl.get_inames_domain(insn_inames)
-        domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
+        # count subscripts, distinguishing loads and stores
         subs_expr = subscript_counter(insn.expression)
         subs_expr = ToCountMap(dict(
             (key + ("load",), val)
             for key, val in six.iteritems(subs_expr.dict)))
-
         subs_assignee = subscript_counter(insn.assignee)
         subs_assignee = ToCountMap(dict(
             (key + ("store",), val)
             for key, val in six.iteritems(subs_assignee.dict)))
 
-        subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain)
+        # get count including local index tags
+        insn_inames = knl.insn_inames(insn)
+        inames_domain = knl.get_inames_domain(insn_inames)
+        domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
+        count_all = count(knl, domain)
+
+        # get count excluding local index tags
+        from loopy.kernel.data import LocalIndexTag
+        insn_inames_nonlocal = [iname for iname in insn_inames if not
+                                isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)]
+        inames_domain_nonlocal = knl.get_inames_domain(insn_inames_nonlocal)
+        domain_nonlocal = (inames_domain_nonlocal.project_out_except(
+                                insn_inames_nonlocal, [dim_type.set]))
+        count_nonlocal = count(knl, domain_nonlocal)
+
+        # use count excluding local index tags for uniform accesses
+        for key in subs_expr.dict:
+            poly = ToCountMap({key: subs_expr.dict[key]})
+            if key[1] == "uniform":
+                subs_poly = subs_poly + poly*count_nonlocal
+            else:
+                subs_poly = subs_poly + poly*count_all
+        for key in subs_assignee.dict:
+            poly = ToCountMap({key: subs_assignee.dict[key]})
+            if key[1] == "uniform":
+                subs_poly = subs_poly + poly*count_nonlocal
+            else:
+                subs_poly = subs_poly + poly*count_all
+        #subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain)
     return subs_poly.dict
 
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 2d9096e38..0fc4fd218 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -378,14 +378,16 @@ def test_gmem_access_counter_mixed():
             "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
                 """
-            c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
+            c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]+x[i,k]
             e[i, k] = g[i,k]*(2+h[i,k])
             """
             ],
             name="mixed", assumptions="n,m,l >= 1")
     knl = lp.add_and_infer_dtypes(knl, dict(
-                a=np.float32, b=np.float32, g=np.float64, h=np.float64))
-    knl = lp.split_iname(knl, "j", 16)
+                a=np.float32, b=np.float32, g=np.float64, h=np.float64,
+                x=np.float32))
+    threads = 16
+    knl = lp.split_iname(knl, "j", threads)
     knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"})
 
     poly = get_gmem_access_poly(knl)  # noqa
@@ -396,10 +398,14 @@ def test_gmem_access_counter_mixed():
     f64uniform = poly[
                     (np.dtype(np.float64), 'uniform', 'load')
                     ].eval_with_dict(params)
+    f32uniform = poly[
+                    (np.dtype(np.float32), 'uniform', 'load')
+                    ].eval_with_dict(params)
     f32nonconsec = poly[
                     (np.dtype(np.float32), 'nonconsecutive', 'load')
                     ].eval_with_dict(params)
     assert f64uniform == 2*n*m
+    assert f32uniform == n*m*l/threads
     assert f32nonconsec == 3*n*m*l
 
     f64uniform = poly[
-- 
GitLab