diff --git a/loopy/statistics.py b/loopy/statistics.py
index 992c95a4e9aee58786220197d7b3372c4b046614..84b6246c36a87334c10dfbd3728bf9efa958f1e8 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -28,6 +28,7 @@ import loopy as lp
 import warnings
 from islpy import dim_type
 import islpy as isl
+from pytools import memoize_in
 from pymbolic.mapper import CombineMapper
 from functools import reduce
 from loopy.kernel.data import Assignment
@@ -600,26 +601,57 @@ def get_gmem_access_poly(knl):  # for now just counting subscripts
     """
 
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
+
+    class CacheHolder(object):
+        pass
+
+    cache_holder = CacheHolder()
+
+    @memoize_in(cache_holder, "insn_count")
+    def get_insn_count(knl, insn_inames, uniform=False):
+        if uniform:
+            from loopy.kernel.data import LocalIndexTag
+            insn_inames = [iname for iname in insn_inames if not
+                           isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)]
+        inames_domain = knl.get_inames_domain(insn_inames)
+        domain = (inames_domain.project_out_except(
+                                insn_inames, [dim_type.set]))
+        return count(knl, domain)
+
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
 
     subs_poly = ToCountMap()
     subscript_counter = GlobalSubscriptCounter(knl)
     for insn in knl.instructions:
-        insn_inames = knl.insn_inames(insn)
-        inames_domain = knl.get_inames_domain(insn_inames)
-        domain = (inames_domain.project_out_except(insn_inames, [dim_type.set]))
+        # count subscripts, distinguishing loads and stores
         subs_expr = subscript_counter(insn.expression)
         subs_expr = ToCountMap(dict(
             (key + ("load",), val)
             for key, val in six.iteritems(subs_expr.dict)))
-
         subs_assignee = subscript_counter(insn.assignee)
         subs_assignee = ToCountMap(dict(
             (key + ("store",), val)
             for key, val in six.iteritems(subs_assignee.dict)))
 
-        subs_poly = subs_poly + (subs_expr + subs_assignee)*count(knl, domain)
+        insn_inames = knl.insn_inames(insn)
+
+        # use count excluding local index tags for uniform accesses
+        for key in subs_expr.dict:
+            poly = ToCountMap({key: subs_expr.dict[key]})
+            if key[1] == "uniform":
+                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames,
+                                                            uniform=True)
+            else:
+                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
+        for key in subs_assignee.dict:
+            poly = ToCountMap({key: subs_assignee.dict[key]})
+            if key[1] == "uniform":
+                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames,
+                                                            uniform=True)
+            else:
+                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
+
     return subs_poly.dict
 
 
diff --git a/test/test_statistics.py b/test/test_statistics.py
index 2d9096e381b7c914b14ce74070a4ed64b82e2eff..0fc4fd218ee2d0f8293e36548a1bb741107ff702 100644
--- a/test/test_statistics.py
+++ b/test/test_statistics.py
@@ -378,14 +378,16 @@ def test_gmem_access_counter_mixed():
             "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
             [
                 """
-                c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]
+                c[i, j, k] = a[i,j,k]*b[i,j,k]/3.0+a[i,j,k]+x[i,k]
                 e[i, k] = g[i,k]*(2+h[i,k])
                 """
             ],
             name="mixed", assumptions="n,m,l >= 1")
     knl = lp.add_and_infer_dtypes(knl, dict(
-        a=np.float32, b=np.float32, g=np.float64, h=np.float64))
-    knl = lp.split_iname(knl, "j", 16)
+        a=np.float32, b=np.float32, g=np.float64, h=np.float64,
+        x=np.float32))
+    threads = 16
+    knl = lp.split_iname(knl, "j", threads)
     knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"})
 
     poly = get_gmem_access_poly(knl)  # noqa
@@ -396,10 +398,14 @@ def test_gmem_access_counter_mixed():
     f64uniform = poly[
             (np.dtype(np.float64), 'uniform', 'load')
             ].eval_with_dict(params)
+    f32uniform = poly[
+            (np.dtype(np.float32), 'uniform', 'load')
+            ].eval_with_dict(params)
     f32nonconsec = poly[
             (np.dtype(np.float32), 'nonconsecutive', 'load')
             ].eval_with_dict(params)
     assert f64uniform == 2*n*m
+    assert f32uniform == n*m*l/threads
     assert f32nonconsec == 3*n*m*l
 
     f64uniform = poly[
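
Usage note (not part of the patch): the sketch below, modeled on test_gmem_access_counter_mixed above, shows how the counts produced by get_gmem_access_poly are queried once this change is in place. The kernel, parameter values, and variable names here are illustrative assumptions; the key layout (dtype, stride classification, load/store) and the eval_with_dict call follow the test.

    import numpy as np
    import loopy as lp
    from loopy.statistics import get_gmem_access_poly

    # toy kernel: a and b vary along the local axis, x does not
    knl = lp.make_kernel(
            "[n,m,l] -> {[i,k,j]: 0<=i<n and 0<=k<m and 0<=j<l}",
            "c[i, j, k] = a[i,j,k]*b[i,j,k] + x[i,k]",
            name="sketch", assumptions="n,m,l >= 1")
    knl = lp.add_and_infer_dtypes(
            knl, dict(a=np.float32, b=np.float32, x=np.float32))
    knl = lp.split_iname(knl, "j", 16)
    knl = lp.tag_inames(knl, {"j_inner": "l.0", "j_outer": "g.0"})

    poly = get_gmem_access_poly(knl)
    params = {"n": 512, "m": 256, "l": 128}

    # a[i,j,k] and b[i,j,k] depend on the l.0 iname -> counted over the
    # full loop domain, as before
    f32_nonconsec = poly[
            (np.dtype(np.float32), "nonconsecutive", "load")
            ].eval_with_dict(params)

    # x[i,k] is uniform across the l.0 axis -> with this patch it is counted
    # over the domain with the local iname projected out, i.e. once per
    # group of 16 work-items rather than once per work-item
    f32_uniform = poly[
            (np.dtype(np.float32), "uniform", "load")
            ].eval_with_dict(params)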