From e2e4372a81087bc8bb111d0ed32785b1c656cc0c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 30 Aug 2012 16:46:46 -0400
Subject: [PATCH] Centralize convexification. Better code for
 get_bounds_checks.

---
 MEMO                    |  4 ++--
 loopy/codegen/bounds.py | 39 ++++++++++++++++-----------------------
 loopy/cse.py            | 29 ++---------------------------
 loopy/isl_helpers.py    | 38 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 58 insertions(+), 52 deletions(-)

diff --git a/MEMO b/MEMO
index 75d421f53..7d1fa43fd 100644
--- a/MEMO
+++ b/MEMO
@@ -97,14 +97,14 @@ Future ideas
 
 - Try, fix indirect addressing
 
-- Use gists (why do disjoint sets arise?)
-
 - Nested slab decomposition (in conjunction with conditional hoisting) could
   generate nested conditional code.
 
 Dealt with
 ^^^^^^^^^^
 
+- Use gists (why do disjoint sets arise?)
+
 - Automatically verify that all array access is within bounds.
 
 - : (as in, Matlab full-slice) in prefetches
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index 9d9005616..6577e4215 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -105,43 +105,36 @@ def constraint_to_code(ccm, cns):
     from loopy.symbolic import constraint_to_expr
     return "%s %s 0" % (ccm(constraint_to_expr(cns), 'i'), comp_op)
 
-def filter_necessary_constraints(implemented_domain, constraints):
-    return [cns
-        for cns in constraints
-        if not implemented_domain.is_subset(
-            isl.align_spaces(
-                isl.BasicSet.universe(cns.get_space()).add_constraint(cns),
-                implemented_domain))]
-
 def generate_bounds_checks(domain, check_inames, implemented_domain):
-    """Will not overapproximate if check_inames consists of all inames in the domain."""
+    """Will not overapproximate."""
 
-    if len(check_inames) == domain.dim(dim_type.set):
-        assert check_inames == frozenset(domain.get_var_names(dim_type.set))
-    else:
-        domain = (domain
-                .eliminate_except(check_inames, [dim_type.set])
-                .remove_divs())
+    domain = (domain
+            .eliminate_except(check_inames, [dim_type.set])
+            .compute_divs())
 
     if isinstance(domain, isl.Set):
         bsets = domain.get_basic_sets()
-        if len(bsets) == 1:
-            domain_bset, = bsets
-        else:
+        if len(bsets) != 1:
             domain = domain.coalesce()
             bsets = domain.get_basic_sets()
-            if len(bsets) == 1:
+            if len(bsets) != 1:
                 raise RuntimeError("domain of inames '%s' projected onto '%s' "
                         "did not reduce to a single conjunction"
                         % (", ".join(domain.get_var_names(dim_type.set)),
                             check_inames))
+
+        domain, = bsets
     else:
-        domain_bset = domain
+        domain = domain
+
+    domain = domain.remove_redundancies()
+    domain = isl.Set.from_basic_set(domain)
+    domain = isl.align_spaces(domain, implemented_domain)
 
-    domain_bset = domain_bset.remove_redundancies()
+    result = domain.gist(implemented_domain)
 
-    return filter_necessary_constraints(
-            implemented_domain, domain_bset.get_constraints())
+    from loopy.isl_helpers import convexify
+    return convexify(result).get_constraints()
 
 def wrap_in_bounds_checks(ccm, domain, check_inames, implemented_domain, stmt):
     bounds_checks = generate_bounds_checks(
diff --git a/loopy/cse.py b/loopy/cse.py
index 08341a2b8..3a7e6e8b8 100644
--- a/loopy/cse.py
+++ b/loopy/cse.py
@@ -650,33 +650,8 @@ def precompute(kernel, subst_use, dtype, sweep_inames=[],
                             storage_axis_names, storage_axis_sources,
                             sweep_inames, invocation_descriptors)
 
-
-    # {{{ try a few ways to get new_domain to be convex
-
-    if len(new_domain.get_basic_sets()) > 1:
-        hull_new_domain = new_domain.simple_hull()
-        if isl.Set.from_basic_set(hull_new_domain) <= new_domain:
-            new_domain = hull_new_domain
-
-    new_domain = new_domain.coalesce()
-
-    if len(new_domain.get_basic_sets()) > 1:
-        hull_new_domain = new_domain.simple_hull()
-        if isl.Set.from_basic_set(hull_new_domain) <= new_domain:
-            new_domain = hull_new_domain
-
-    if isinstance(new_domain, isl.Set):
-        dom_bsets = new_domain.get_basic_sets()
-        if len(dom_bsets) > 1:
-            print "PIECES:"
-            for dbs in dom_bsets:
-                print "  %s" % (isl.Set.from_basic_set(dbs).gist(new_domain))
-            raise NotImplementedError("Substitution '%s' yielded a non-convex footprint"
-                    % subst_name)
-
-        new_domain, = dom_bsets
-
-    # }}}
+    from loopy.isl_helpers import convexify
+    new_domain = convexify(new_domain)
 
     # {{{ set up compute insn
 
diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index c0e0b9708..b67411267 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -277,6 +277,44 @@ def is_nonnegative(expr, over_set):
 
 
 
+def convexify(domain):
+    """Try a few ways to get *domain* to be a BasicSet, i.e.
+    explicitly convex.
+    """
+
+    if isinstance(domain, isl.BasicSet):
+        return domain
+
+    dom_bsets = domain.get_basic_sets()
+    if len(dom_bsets) == 1:
+        domain, = dom_bsets
+        return domain
+
+    hull_domain = domain.simple_hull()
+    if isl.Set.from_basic_set(hull_domain) <= domain:
+        return hull_domain
+
+    domain = domain.coalesce()
+
+    dom_bsets = domain.get_basic_sets()
+    if len(domain.get_basic_sets()) == 1:
+        domain, = dom_bsets
+        return domain
+
+    hull_domain = domain.simple_hull()
+    if isl.Set.from_basic_set(hull_domain) <= domain:
+        return hull_domain
+
+    dom_bsets = domain.get_basic_sets()
+    assert len(dom_bsets) > 1
+
+    print "PIECES:"
+    for dbs in dom_bsets:
+        print "  %s" % (isl.Set.from_basic_set(dbs).gist(domain))
+    raise NotImplementedError("Could not find convex representation of set")
+
+
+
 
 
 # vim: foldmethod=marker
-- 
GitLab