diff --git a/loopy/statistics.py b/loopy/statistics.py index 17c5bd3557bd65eddd2d9a35202a604c552e4e19..e27a0f482885658888c97081e4fc1d97fcd149fd 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1527,11 +1527,10 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, + access_counter_l(insn.expression) ).with_set_attributes(direction="load") - access_assignee_g = access_counter_g(insn.assignee).with_set_attributes( - direction="store") - - # FIXME: (!!!!) for now, don't count writes to local mem - # (^this is updated in a branch that will be merged soon) + access_assignee = ( + access_counter_g(insn.assignee) + + access_counter_l(insn.assignee) + ).with_set_attributes(direction="store") # use count excluding local index tags for uniform accesses for key, val in six.iteritems(access_expr.count_map): @@ -1542,7 +1541,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, * get_insn_count(knl, insn.id, key.count_granularity)) #currently not counting stride of local mem access - for key, val in six.iteritems(access_assignee_g.count_map): + for key, val in six.iteritems(access_assignee.count_map): access_map = ( access_map diff --git a/test/test_statistics.py b/test/test_statistics.py index b9c7185c21af4782af8fb284e72ac6041d5f98da..c04257ff86bd0e726cce2f1481c55cec0c8275e1 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -925,6 +925,12 @@ def test_all_counters_parallel_matmul(): ].eval_with_dict(params) assert local_mem_l == n*m*ell*2 + local_mem_s = local_mem_map[lp.MemAccess('local', np.dtype(np.float32), + direction='store', + count_granularity=CG.WORKITEM) + ].eval_with_dict(params) + assert local_mem_s == n*m*ell*2/bsize + def test_gather_access_footprint(): knl = lp.make_kernel(