diff --git a/loopy/statistics.py b/loopy/statistics.py
index 17c5bd3557bd65eddd2d9a35202a604c552e4e19..e27a0f482885658888c97081e4fc1d97fcd149fd 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -1527,11 +1527,10 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
                     + access_counter_l(insn.expression)
                     ).with_set_attributes(direction="load")
 
-            access_assignee_g = access_counter_g(insn.assignee).with_set_attributes(
-                    direction="store")
-
-            # FIXME: (!!!!) for now, don't count writes to local mem
-            # (^this is updated in a branch that will be merged soon)
+            access_assignee = (
+                    access_counter_g(insn.assignee)
+                    + access_counter_l(insn.assignee)
+                    ).with_set_attributes(direction="store")
 
             # use count excluding local index tags for uniform accesses
             for key, val in six.iteritems(access_expr.count_map):
@@ -1542,7 +1541,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
                         * get_insn_count(knl, insn.id, key.count_granularity))
                 #currently not counting stride of local mem access
 
-            for key, val in six.iteritems(access_assignee_g.count_map):
+            for key, val in six.iteritems(access_assignee.count_map):
 
                 access_map = (
                         access_map
diff --git a/test/test_statistics.py b/test/test_statistics.py
index b9c7185c21af4782af8fb284e72ac6041d5f98da..c04257ff86bd0e726cce2f1481c55cec0c8275e1 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -925,6 +925,12 @@ def test_all_counters_parallel_matmul():
                                 ].eval_with_dict(params)
     assert local_mem_l == n*m*ell*2
 
+    local_mem_s = local_mem_map[lp.MemAccess('local', np.dtype(np.float32),
+                                             direction='store',
+                                             count_granularity=CG.WORKITEM)
+                                ].eval_with_dict(params)
+    assert local_mem_s == n*m*ell*2/bsize
+
 
 def test_gather_access_footprint():
     knl = lp.make_kernel(