From 8a74723829db9c323c462e95b1d064ad6598325b Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 8 Feb 2016 15:35:33 -0600
Subject: [PATCH] Finish up uniform load counting

---
 loopy/statistics.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index 84b6246c3..9ef292f63 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -607,12 +607,12 @@ def get_gmem_access_poly(knl):  # for now just counting subscripts
 
     cache_holder = CacheHolder()
 
-    @memoize_in(cache_holder)
+    @memoize_in(cache_holder, "insn_count")
     def get_insn_count(knl, insn_inames, uniform=False):
         if uniform:
             from loopy.kernel.data import LocalIndexTag
             insn_inames = [iname for iname in insn_inames if not
-                           isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] 
+                           isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)]
         inames_domain = knl.get_inames_domain(insn_inames)
         domain = (inames_domain.project_out_except(
                                 insn_inames, [dim_type.set]))
@@ -640,15 +640,13 @@ def get_gmem_access_poly(knl):  # for now just counting subscripts
         for key in subs_expr.dict:
             poly = ToCountMap({key: subs_expr.dict[key]})
             if key[1] == "uniform":
-                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames,
-                                                            uniform=True)
+                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, True)
             else:
                 subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
         for key in subs_assignee.dict:
             poly = ToCountMap({key: subs_assignee.dict[key]})
             if key[1] == "uniform":
-                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames,
-                                                            uniform=True)
+                subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames, True)
             else:
                 subs_poly = subs_poly + poly*get_insn_count(knl, insn_inames)
 
-- 
GitLab