From 7f3442f59dc956d59f3efc44d66737ff7e848a6c Mon Sep 17 00:00:00 2001 From: James Stevens Date: Sat, 30 Jan 2016 23:51:38 -0600 Subject: [PATCH] uniform loads and stores now only counted once per thread group --- loopy/statistics.py | 35 ++++++++++++++++++++++++++++++----- test/test_statistics.py | 12 +++++++++--- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index 992c95a4e..e8ff22412 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -606,20 +606,45 @@ def get_gmem_access_poly(knl): # for now just counting subscripts subs_poly = ToCountMap() subscript_counter = GlobalSubscriptCounter(knl) for insn in knl.instructions: - insn_inames = knl.insn_inames(insn) - inames_domain = knl.get_inames_domain(insn_inames) - domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) + # count subscripts, distinguishing loads and stores subs_expr = subscript_counter(insn.expression) subs_expr = ToCountMap(dict( (key + ("load",), val) for key, val in six.iteritems(subs_expr.dict))) - subs_assignee = subscript_counter(insn.assignee) subs_assignee = ToCountMap(dict( (key + ("store",), val) for key, val in six.iteritems(subs_assignee.dict))) - subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain) + # get count including local index tags + insn_inames = knl.insn_inames(insn) + inames_domain = knl.get_inames_domain(insn_inames) + domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) + count_all = count(knl, domain) + + # get count excluding local index tags + from loopy.kernel.data import LocalIndexTag + insn_inames_nonlocal = [iname for iname in insn_inames if not + isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] + inames_domain_nonlocal = knl.get_inames_domain(insn_inames_nonlocal) + domain_nonlocal = (inames_domain_nonlocal.project_out_except( + insn_inames_nonlocal, [dim_type.set])) + count_nonlocal = count(knl, domain_nonlocal) + + # use count excluding local index tags for uniform accesses + for key in subs_expr.dict: + poly = ToCountMap({key: subs_expr.dict[key]}) + if key[1] == "uniform": + subs_poly = subs_poly + poly*count_nonlocal + else: + subs_poly = subs_poly + poly*count_all + for key in subs_assignee.dict: + poly = ToCountMap({key: subs_assignee.dict[key]}) + if key[1] == "uniform": + subs_poly = subs_poly + poly*count_nonlocal + else: + subs_poly = subs_poly + poly*count_all + #subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain) return subs_poly.dict diff --git a/test/test_statistics.py b/test/test_statistics.py index 2d9096e38..0fc4fd218 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -378,14 +378,16 @@ def test_gmem_access_counter_mixed(): "[n,m,l] -> {[i,k,j]: 0<=i