From f6688709169a430b0c2783630bae772fa563146f Mon Sep 17 00:00:00 2001
From: James Stevens <jdsteve2@porter.cs.illinois.edu>
Date: Sun, 7 Feb 2016 14:45:44 -0600
Subject: [PATCH] working on memoizing, get_dram_access_poly currently broken

---
 loopy/statistics.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/loopy/statistics.py b/loopy/statistics.py
index ca018a9c7..84b6246c3 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -28,7 +28,7 @@ import loopy as lp
 import warnings
 from islpy import dim_type
 import islpy as isl
-from pytools import memoize_method
+from pytools import memoize_in
 from pymbolic.mapper import CombineMapper
 from functools import reduce
 from loopy.kernel.data import Assignment
@@ -553,18 +553,6 @@ def sum_ops_to_dtypes(op_poly_dict):
     return result
 
 
-@memoize_method
-def get_insn_count(knl, insn_inames, uniform=False):
-    if uniform:
-        from loopy.kernel.data import LocalIndexTag
-        insn_inames = [iname for iname in insn_inames if not
-                       isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] 
-    inames_domain = knl.get_inames_domain(insn_inames)
-    domain = (inames_domain.project_out_except(
-                            insn_inames, [dim_type.set]))
-    return count(knl, domain)
-
-
 # {{{ get_gmem_access_poly
 def get_gmem_access_poly(knl):  # for now just counting subscripts
 
@@ -613,6 +601,23 @@ def get_gmem_access_poly(knl):  # for now just counting subscripts
     """
 
     from loopy.preprocess import preprocess_kernel, infer_unknown_types
+
+    class CacheHolder(object):
+        pass
+
+    cache_holder = CacheHolder()
+
+    @memoize_in(cache_holder)
+    def get_insn_count(knl, insn_inames, uniform=False):
+        if uniform:
+            from loopy.kernel.data import LocalIndexTag
+            insn_inames = [iname for iname in insn_inames if not
+                           isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] 
+        inames_domain = knl.get_inames_domain(insn_inames)
+        domain = (inames_domain.project_out_except(
+                                insn_inames, [dim_type.set]))
+        return count(knl, domain)
+
     knl = infer_unknown_types(knl, expect_completion=True)
     knl = preprocess_kernel(knl)
 
-- 
GitLab