From cadc7fce17dd7ffb8937d1b0720167d5f3dfd581 Mon Sep 17 00:00:00 2001
From: jdsteve2 <jdsteve2@illinois.edu>
Date: Sun, 25 Feb 2018 00:33:22 -0600
Subject: [PATCH] Count local memory stores in get_mem_access_map

---
 loopy/statistics.py     | 11 +++++------
 test/test_statistics.py |  6 ++++++
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 17c5bd355..e27a0f482 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -1527,11 +1527,10 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
                     + access_counter_l(insn.expression)
                     ).with_set_attributes(direction="load")
 
-            access_assignee_g = access_counter_g(insn.assignee).with_set_attributes(
-                    direction="store")
-
-            # FIXME: (!!!!) for now, don't count writes to local mem
-            # (^this is updated in a branch that will be merged soon)
+            access_assignee = (
+                    access_counter_g(insn.assignee)
+                    + access_counter_l(insn.assignee)
+                    ).with_set_attributes(direction="store")
 
             # use count excluding local index tags for uniform accesses
             for key, val in six.iteritems(access_expr.count_map):
@@ -1542,7 +1541,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
                         * get_insn_count(knl, insn.id, key.count_granularity))
                 #currently not counting stride of local mem access
 
-            for key, val in six.iteritems(access_assignee_g.count_map):
+            for key, val in six.iteritems(access_assignee.count_map):
 
                 access_map = (
                         access_map
diff --git a/test/test_statistics.py b/test/test_statistics.py
index b9c7185c2..c04257ff8 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -925,6 +925,12 @@ def test_all_counters_parallel_matmul():
                                 ].eval_with_dict(params)
     assert local_mem_l == n*m*ell*2
 
+    local_mem_s = local_mem_map[lp.MemAccess('local', np.dtype(np.float32),
+                                             direction='store',
+                                             count_granularity=CG.WORKITEM)
+                                ].eval_with_dict(params)
+    assert local_mem_s == n*m*ell*2/bsize
+
 
 def test_gather_access_footprint():
     knl = lp.make_kernel(
-- 
GitLab