From 12452330422ceca19ac384d16b17a2ccc63ecb8a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 9 Nov 2011 10:10:13 -0500
Subject: [PATCH] Cache result of dim_{max,min}.

---
 MEMO              |  4 ++--
 loopy/__init__.py |  2 +-
 loopy/cse.py      |  3 ++-
 loopy/kernel.py   | 53 +++++++++++++++++++++++++++++++++++++++--------
 4 files changed, 49 insertions(+), 13 deletions(-)

diff --git a/MEMO b/MEMO
index 3111326d3..d80e6519f 100644
--- a/MEMO
+++ b/MEMO
@@ -41,8 +41,6 @@ To-do
 
 - CSE should be more like variable assignment
 
-- dim_max caching
-
 - Fix all tests
 
 - Deal with equality constraints.
@@ -88,6 +86,8 @@ Future ideas
 Dealt with
 ^^^^^^^^^^
 
+- dim_max caching
+
 - Exhaust the search for a no-boost solution first, before looking
   for a schedule with boosts.
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index a5f1cb632..e7d8ce637 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -168,7 +168,7 @@ def make_kernel(*args, **kwargs):
 
             base_indices, shape = \
                     find_var_base_indices_and_shape_from_inames(
-                            new_domain, assignee_indices)
+                            new_domain, assignee_indices, knl.cache_manager)
 
             new_temp_vars[assignee_name] = TemporaryVariable(
                     name=assignee_name,
diff --git a/loopy/cse.py b/loopy/cse.py
index 46c189f3f..ee22a8b8f 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -498,7 +498,8 @@ def realize_cse(kernel, cse_tag, dtype, independent_inames=[],
 
     target_var_base_indices, target_var_shape = \
             find_var_base_indices_and_shape_from_inames(
-                    new_domain, independent_inames)
+                    new_domain, independent_inames,
+                    kernel.cache_manager)
 
     new_temporary_variables = kernel.temporary_variables.copy()
     new_temporary_variables[target_var_name] = TemporaryVariable(
diff --git a/loopy/kernel.py b/loopy/kernel.py
index b301ac985..dfdae0502 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -452,6 +452,8 @@ class LoopKernel(Record):
         workgroup axes to ther sizes, e.g. *{0: 16}* forces axis 0 to be
         length 16.
 
+    :ivar cache_manager:
+
     The following instance variables are only used until :func:`loopy.kernel.make_kernel` is
     finished:
 
@@ -466,7 +468,8 @@ class LoopKernel(Record):
             iname_slab_increments={},
             temporary_variables={},
             local_sizes={},
-            iname_to_tag={}, iname_to_tag_requests=None, cses={}, substitutions={}):
+            iname_to_tag={}, iname_to_tag_requests=None, cses={}, substitutions={},
+            cache_manager=None):
         """
         :arg domain: a :class:`islpy.BasicSet`, or a string parseable to a basic set by the isl.
             Example: "{[i,j]: 0<=i < 10 and 0<= j < 9}"
@@ -475,6 +478,9 @@ class LoopKernel(Record):
 
         import re
 
+        if cache_manager is None:
+            cache_manager = SetOperationCacheManager()
+
         if isinstance(domain, str):
             ctx = isl.Context()
             domain = isl.Set.read_from_str(ctx, domain)
@@ -650,7 +656,8 @@ class LoopKernel(Record):
                 local_sizes=local_sizes,
                 iname_to_tag=iname_to_tag,
                 iname_to_tag_requests=iname_to_tag_requests,
-                cses=cses, substitutions=substitutions)
+                cses=cses, substitutions=substitutions,
+                cache_manager=cache_manager)
 
     def make_unique_instruction_id(self, insns=None, based_on="insn", extra_used_ids=set()):
         if insns is None:
@@ -863,12 +870,14 @@ class LoopKernel(Record):
                 isl.align_spaces(self.assumptions, self.domain)
                 & self.domain)
         lower_bound_pw_aff = (
-                dom_intersect_assumptions
-                .dim_min(self.iname_to_dim[iname][1])
+                self.cache_manager.dim_min(
+                    dom_intersect_assumptions,
+                    self.iname_to_dim[iname][1])
                 .coalesce())
         upper_bound_pw_aff = (
-                dom_intersect_assumptions
-                .dim_max(self.iname_to_dim[iname][1])
+                self.cache_manager.dim_max(
+                    dom_intersect_assumptions,
+                    self.iname_to_dim[iname][1])
                 .coalesce())
 
         class BoundsRecord(Record):
@@ -1035,14 +1044,14 @@ class LoopKernel(Record):
 
 
 
-def find_var_base_indices_and_shape_from_inames(domain, inames):
+def find_var_base_indices_and_shape_from_inames(domain, inames, cache_manager):
     base_indices = []
     shape = []
 
     iname_to_dim = domain.get_space().get_var_dict()
     for iname in inames:
-        lower_bound_pw_aff = domain.dim_min(iname_to_dim[iname][1])
-        upper_bound_pw_aff = domain.dim_max(iname_to_dim[iname][1])
+        lower_bound_pw_aff = cache_manager.dim_min(domain, iname_to_dim[iname][1])
+        upper_bound_pw_aff = cache_manager.dim_max(domain, iname_to_dim[iname][1])
 
         from loopy.isl_helpers import static_max_of_pw_aff
         from loopy.symbolic import pw_aff_to_expr
@@ -1078,4 +1087,30 @@ def get_dot_dependency_graph(kernel, iname_cluster=False, iname_edge=True):
 
 
 
+class SetOperationCacheManager:
+    def __init__(self):
+        # mapping: set hash -> [(set, op, args, result)]
+        self.cache = {}
+
+    def op(self, set, op, args):
+        hashval = hash(set)
+        bucket = self.cache.setdefault(hashval, [])
+
+        for bkt_set, bkt_op, bkt_args, result  in bucket:
+            if set.plain_is_equal(bkt_set) and op == bkt_op and args == bkt_args:
+                return result
+
+        result = getattr(set, op)(*args)
+        bucket.append((set, op, args, result))
+        return result
+
+    def dim_min(self, set, *args):
+        return self.op(set, "dim_min", args)
+
+    def dim_max(self, set, *args):
+        return self.op(set, "dim_max", args)
+
+
+
+
 # vim: foldmethod=marker
-- 
GitLab