From f6688709169a430b0c2783630bae772fa563146f Mon Sep 17 00:00:00 2001 From: James Stevens <jdsteve2@porter.cs.illinois.edu> Date: Sun, 7 Feb 2016 14:45:44 -0600 Subject: [PATCH] working on memoizing, get_dram_access_poly currently broken --- loopy/statistics.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/loopy/statistics.py b/loopy/statistics.py index ca018a9c7..84b6246c3 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -28,7 +28,7 @@ import loopy as lp import warnings from islpy import dim_type import islpy as isl -from pytools import memoize_method +from pytools import memoize_in from pymbolic.mapper import CombineMapper from functools import reduce from loopy.kernel.data import Assignment @@ -553,18 +553,6 @@ def sum_ops_to_dtypes(op_poly_dict): return result -@memoize_method -def get_insn_count(knl, insn_inames, uniform=False): - if uniform: - from loopy.kernel.data import LocalIndexTag - insn_inames = [iname for iname in insn_inames if not - isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] - inames_domain = knl.get_inames_domain(insn_inames) - domain = (inames_domain.project_out_except( - insn_inames, [dim_type.set])) - return count(knl, domain) - - # {{{ get_gmem_access_poly def get_gmem_access_poly(knl): # for now just counting subscripts @@ -613,6 +601,23 @@ def get_gmem_access_poly(knl): # for now just counting subscripts """ from loopy.preprocess import preprocess_kernel, infer_unknown_types + + class CacheHolder(object): + pass + + cache_holder = CacheHolder() + + @memoize_in(cache_holder) + def get_insn_count(knl, insn_inames, uniform=False): + if uniform: + from loopy.kernel.data import LocalIndexTag + insn_inames = [iname for iname in insn_inames if not + isinstance(knl.iname_to_tag.get(iname), LocalIndexTag)] + inames_domain = knl.get_inames_domain(insn_inames) + domain = (inames_domain.project_out_except( + insn_inames, [dim_type.set])) + return count(knl, domain) + knl = infer_unknown_types(knl, expect_completion=True) knl = preprocess_kernel(knl) -- GitLab