From cb15587b71e6a7805edc1bee7feddf0b58584497 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 12 Jul 2015 22:05:49 -0500
Subject: [PATCH] Bring (mildly fixed) perf tweak for dim_{min,max} over from
 islpy

---
 loopy/isl_helpers.py  | 123 ++++++++++++++++++++++++++++++++++++++++++
 loopy/kernel/tools.py |   7 +--
 2 files changed, 127 insertions(+), 3 deletions(-)

diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index e8409a9ce..9889fa12e 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -384,4 +384,127 @@ def obj_involves_variable(obj, var_name):
 
     return False
 
+
+# {{{ performance tweak for dim_{min,max}: project first
+
+def _runs_in_integer_set(s, max_int=None):
+    """Yield ``(start, length)`` for each maximal run of consecutive
+    integers of *s* within ``range(max_int)``, in increasing order.
+
+    :arg max_int: exclusive upper bound of the scanned range; defaults to
+        ``max(s) + 1`` so that the largest member of *s* is included.
+    """
+    if not s:
+        return
+    if max_int is None:
+        # The bound is exclusive; a bare max(s) would drop the last member.
+        max_int = max(s) + 1
+
+    i = 0
+    while i < max_int:
+        if i not in s:
+            i += 1
+            continue
+        start = i
+        while i < max_int and i in s:
+            i += 1
+        yield (start, i - start)
+
+
+class TooManyInteractingDims(Exception):
+    """Raised when the set of interacting dimensions reaches the limit."""
+
+
+def _find_aff_dims(aff, dim_types_and_gen_dim_types):
+    """Return the set of ``(gen_dim_type, index)`` pairs for dimensions with
+    a nonzero coefficient in *aff* or, recursively, in any of its divs.
+
+    :arg dim_types_and_gen_dim_types: pairs ``(dim_type, gen_dim_type)``:
+        the type to query on *aff* and the type to report in the result.
+    """
+    found = set(
+            (reported_dt, pos)
+            for queried_dt, reported_dt in dim_types_and_gen_dim_types
+            for pos in range(aff.dim(queried_dt))
+            if not aff.get_coefficient_val(queried_dt, pos).is_zero())
+
+    for div_nr in range(aff.dim(dim_type.div)):
+        found |= _find_aff_dims(aff.get_div(div_nr), dim_types_and_gen_dim_types)
+    return found
+
+
+def _transitive_closure(graph_dict):
+    """Unimplemented placeholder: does nothing and returns None."""
+
+
+def _find_noninteracting_dims(obj, dt, idx, other_dt, stop_at=6):
+    """Return the indices of the *other_dt* dimensions of *obj* that are not
+    linked, through any chain of constraints, to dimension ``(dt, idx)``.
+
+    Dimensions are linked by a constraint if they both have a nonzero
+    coefficient in it (divs included). Only param/set dims are tracked.
+
+    :arg obj: an :class:`islpy.BasicSet` or :class:`islpy.Set`.
+    :arg stop_at: number of linked dims (including ``(dt, idx)``) at which
+        the search gives up by raising :exc:`TooManyInteractingDims`.
+    :raises TypeError: if *obj* is of neither supported type.
+    """
+    if isinstance(obj, isl.BasicSet):
+        pieces = [obj]
+    elif isinstance(obj, isl.Set):
+        pieces = obj.get_basic_sets()
+    else:
+        raise TypeError("unsupported arg type '%s'" % type(obj))
+
+    tracked_dts = [(dim_type.param, dim_type.param), (dim_type.in_, dim_type.set)]
+
+    # One dim group per constraint; single-dim groups cannot link anything.
+    groups = []
+    for piece in pieces:
+        for cns in piece.get_constraints():
+            group = _find_aff_dims(cns.get_aff(), tracked_dts)
+            if len(group) > 1:
+                groups.append(group)
+
+    # Grow the connected component of (dt, idx) until a sweep adds nothing.
+    linked = set([(dt, idx)])
+    grew = True
+    while grew:
+        grew = False
+        for group in groups:
+            if linked & group and not group <= linked:
+                linked |= group
+                grew = True
+            if len(linked) >= stop_at:
+                raise TooManyInteractingDims()
+
+    other_linked = set(i for linked_dt, i in linked if linked_dt == other_dt)
+    return set(range(obj.dim(other_dt))) - other_linked
+
+
+def _eliminate_noninteracting(obj, dt, idx, other_dt):
+    """Eliminate from *obj* the *other_dt* dimension runs yielded by
+    :func:`_runs_in_integer_set` over the dims not interacting with
+    ``(dt, idx)``; return *obj* as-is on :exc:`TooManyInteractingDims`."""
+    try:
+        removable = _find_noninteracting_dims(obj, dt, idx, other_dt)
+    except TooManyInteractingDims:
+        return obj
+    for run_start, run_len in _runs_in_integer_set(removable):
+        obj = obj.eliminate(other_dt, run_start, run_len)
+    return obj
+
+
+def dim_min_with_elimination(obj, idx):
+    out, param = dim_type.out, dim_type.param  # eliminate params unrelated to idx
+    return _eliminate_noninteracting(obj, out, idx, param).dim_min(idx)
+
+
+def dim_max_with_elimination(obj, idx):
+    out, param = dim_type.out, dim_type.param  # eliminate params unrelated to idx
+    return _eliminate_noninteracting(obj, out, idx, param).dim_max(idx)
+
+# }}}
+
+
 # vim: foldmethod=marker
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index d1093c01a..b59c40731 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -28,7 +28,6 @@ THE SOFTWARE.
 
 
 import numpy as np
-import islpy as isl
 from islpy import dim_type
 from loopy.diagnostic import LoopyError
 
@@ -257,10 +256,12 @@ class SetOperationCacheManager:
         return result
 
     def dim_min(self, set, *args):
-        return self.op(set, "dim_min", isl.dim_min_with_elimination, args)
+        from loopy.isl_helpers import dim_min_with_elimination as dim_min_impl
+        return self.op(set, "dim_min", dim_min_impl, args)
 
     def dim_max(self, set, *args):
-        return self.op(set, "dim_max", isl.dim_max_with_elimination, args)
+        from loopy.isl_helpers import dim_max_with_elimination as dim_max_impl
+        return self.op(set, "dim_max", dim_max_impl, args)
 
     def base_index_and_length(self, set, iname, context=None):
         if not isinstance(iname, int):
-- 
GitLab