diff --git a/loopy/statistics.py b/loopy/statistics.py index 84b6246c36a87334c10dfbd3728bf9efa958f1e8..9ef292f63c12ef06f553738752d9ea894b14eb0c 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -607,12 +607,12 @@ def get_gmem_access_poly(knl): # for now just counting subscripts cache_holder = CacheHolder() - @memoize_in(cache_holder) + @memoize_in(cache_holder, "insn_count") def get_insn_count(knl, insn_inames, uniform=False): if uniform: from loopy.kernel.data import LocalIndexTag insn_inames = [iname for iname in insn_inames if not - isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] + isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] inames_domain = knl.get_inames_domain(insn_inames) domain = (inames_domain.project_out_except( insn_inames, [dim_type.set])) @@ -640,15 +640,13 @@ def get_gmem_access_poly(knl): # for now just counting subscripts for key in subs_expr.dict: poly = ToCountMap({key: subs_expr.dict[key]}) if key[1] == "uniform": - subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, - uniform=True) + subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, True) else: subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames) for key in subs_assignee.dict: poly = ToCountMap({key: subs_assignee.dict[key]}) if key[1] == "uniform": - subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, - uniform=True) + subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, True) else: subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)