From cadc7fce17dd7ffb8937d1b0720167d5f3dfd581 Mon Sep 17 00:00:00 2001 From: jdsteve2 Date: Sun, 25 Feb 2018 00:33:22 -0600 Subject: [PATCH] no longer ignoring local stores in mem access counting --- loopy/statistics.py | 11 +++++------ test/test_statistics.py | 6 ++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 17c5bd355..e27a0f482 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1527,11 +1527,10 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, + access_counter_l(insn.expression) ).with_set_attributes(direction="load") - access_assignee_g = access_counter_g(insn.assignee).with_set_attributes( - direction="store") - - # FIXME: (!!!!) for now, don't count writes to local mem - # (^this is updated in a branch that will be merged soon) + access_assignee = ( + access_counter_g(insn.assignee) + + access_counter_l(insn.assignee) + ).with_set_attributes(direction="store") # use count excluding local index tags for uniform accesses for key, val in six.iteritems(access_expr.count_map): @@ -1542,7 +1541,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False, * get_insn_count(knl, insn.id, key.count_granularity)) #currently not counting stride of local mem access - for key, val in six.iteritems(access_assignee_g.count_map): + for key, val in six.iteritems(access_assignee.count_map): access_map = ( access_map diff --git a/test/test_statistics.py b/test/test_statistics.py index b9c7185c2..c04257ff8 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -925,6 +925,12 @@ def test_all_counters_parallel_matmul(): ].eval_with_dict(params) assert local_mem_l == n*m*ell*2 + local_mem_s = local_mem_map[lp.MemAccess('local', np.dtype(np.float32), + direction='store', + count_granularity=CG.WORKITEM) + ].eval_with_dict(params) + assert local_mem_s == n*m*ell*2/bsize + def test_gather_access_footprint(): knl = lp.make_kernel( -- GitLab