diff --git a/loopy/statistics.py b/loopy/statistics.py index 992c95a4e9aee58786220197d7b3372c4b046614..e8ff22412a85731699f0caaf6d6153d822eeb6dd 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -606,20 +606,45 @@ def get_gmem_access_poly(knl): # for now just counting subscripts subs_poly = ToCountMap() subscript_counter = GlobalSubscriptCounter(knl) for insn in knl.instructions: - insn_inames = knl.insn_inames(insn) - inames_domain = knl.get_inames_domain(insn_inames) - domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) + # count subscripts, distinguishing loads and stores subs_expr = subscript_counter(insn.expression) subs_expr = ToCountMap(dict( (key + ("load",), val) for key, val in six.iteritems(subs_expr.dict))) - subs_assignee = subscript_counter(insn.assignee) subs_assignee = ToCountMap(dict( (key + ("store",), val) for key, val in six.iteritems(subs_assignee.dict))) - subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain) + # get count including local index tags + insn_inames = knl.insn_inames(insn) + inames_domain = knl.get_inames_domain(insn_inames) + domain = (inames_domain.project_out_except(insn_inames, [dim_type.set])) + count_all = count(knl, domain) + + # get count excluding local index tags + from loopy.kernel.data import LocalIndexTag + insn_inames_nonlocal = [iname for iname in insn_inames if not + isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] + inames_domain_nonlocal = knl.get_inames_domain(insn_inames_nonlocal) + domain_nonlocal = (inames_domain_nonlocal.project_out_except( + insn_inames_nonlocal, [dim_type.set])) + count_nonlocal = count(knl, domain_nonlocal) + + # use count excluding local index tags for uniform accesses + for key in subs_expr.dict: + poly = ToCountMap({key: subs_expr.dict[key]}) + if key[1] == "uniform": + subs_poly = subs_poly + poly*count_nonlocal + else: + subs_poly = subs_poly + poly*count_all + for key in subs_assignee.dict: + poly = ToCountMap({key: subs_assignee.dict[key]}) + if key[1] == "uniform": + subs_poly = subs_poly + poly*count_nonlocal + else: + subs_poly = subs_poly + poly*count_all + #subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain) return subs_poly.dict diff --git a/test/test_statistics.py b/test/test_statistics.py index 2d9096e381b7c914b14ce74070a4ed64b82e2eff..0fc4fd218ee2d0f8293e36548a1bb741107ff702 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -378,14 +378,16 @@ def test_gmem_access_counter_mixed(): "[n,m,l] -> {[i,k,j]: 0<=i