From f85486e087a3209f34cc3a852cd492261f35a6f0 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 12 Aug 2012 14:01:08 -0400
Subject: [PATCH] A few more multi-domain fixes.

---
 loopy/__init__.py       |  3 +++
 loopy/codegen/bounds.py |  6 ++++--
 loopy/isl_helpers.py    |  5 ++---
 loopy/kernel.py         |  5 ++---
 test/test_loopy.py      | 10 ++++++++--
 5 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/loopy/__init__.py b/loopy/__init__.py
index fe7f8859b..2f8fd8f5d 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -80,6 +80,9 @@ def split_dimension(kernel, split_iname, inner_length,
         inner_iname = split_iname+"_inner"
 
     def process_set(s):
+        if split_iname not in s.get_var_dict():
+            return s
+
         outer_var_nr = s.dim(dim_type.set)
         inner_var_nr = s.dim(dim_type.set)+1
 
diff --git a/loopy/codegen/bounds.py b/loopy/codegen/bounds.py
index 0237295b2..1ec320bef 100644
--- a/loopy/codegen/bounds.py
+++ b/loopy/codegen/bounds.py
@@ -22,6 +22,7 @@ def get_bounds_constraints(set, iname, admissible_inames, allow_parameters):
             elim_type.append(dim_type.param)
 
         set = set.eliminate_except(admissible_inames, elim_type)
+        set = set.compute_divs()
 
     basic_sets = set.get_basic_sets()
     if len(basic_sets) > 1:
@@ -106,11 +107,12 @@ def constraint_to_code(ccm, cns):
     return "%s %s 0" % (ccm(constraint_to_expr(cns)), comp_op)
 
 def filter_necessary_constraints(implemented_domain, constraints):
-    space = implemented_domain.get_space()
     return [cns
         for cns in constraints
         if not implemented_domain.is_subset(
-            isl.Set.universe(space).add_constraint(cns))]
+            isl.align_spaces(
+                isl.BasicSet.universe(cns.get_space()).add_constraint(cns),
+                implemented_domain))]
 
 def generate_bounds_checks(domain, check_inames, implemented_domain):
     """Will not overapproximate."""
diff --git a/loopy/isl_helpers.py b/loopy/isl_helpers.py
index 3eb1ccba2..dab7648af 100644
--- a/loopy/isl_helpers.py
+++ b/loopy/isl_helpers.py
@@ -105,11 +105,10 @@ def dump_space(ls):
 def make_slab(space, iname, start, stop):
     zero = isl.Aff.zero_on_domain(space)
 
-    from islpy import align_spaces
     if isinstance(start, (isl.Aff, isl.PwAff)):
-        start = align_spaces(pw_aff_to_aff(start), zero)
+        start, zero = isl.align_two(pw_aff_to_aff(start), zero)
     if isinstance(stop, (isl.Aff, isl.PwAff)):
-        stop = align_spaces(pw_aff_to_aff(stop), zero)
+        stop, zero = isl.align_two(pw_aff_to_aff(stop), zero)
 
     if isinstance(start, int): start = zero + start
     if isinstance(stop, int): stop = zero + stop
diff --git a/loopy/kernel.py b/loopy/kernel.py
index 24ca824e7..36575b103 100644
--- a/loopy/kernel.py
+++ b/loopy/kernel.py
@@ -1235,7 +1235,8 @@ class LoopKernel(Record):
         d_var_dict = domain.get_var_dict()
 
         dom_intersect_assumptions = (
-                isl.align_spaces(self.assumptions, domain) & domain)
+                isl.align_spaces(self.assumptions, domain, obj_bigger_ok=True)
+                & domain)
         lower_bound_pw_aff = (
                 self.cache_manager.dim_min(
                     dom_intersect_assumptions,
@@ -1614,6 +1615,4 @@ class SetOperationCacheManager:
 
 
 
-
-
 # vim: foldmethod=marker
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 1de7eeb03..37d07f04e 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -368,7 +368,10 @@ def test_dependent_loop_bounds_2(ctx_factory):
     ctx = ctx_factory()
 
     knl = lp.make_kernel(ctx.devices[0],
-            "[n,row_len] -> {[i,jj]: 0<=i<n and 0<=jj<row_len}",
+            [
+                "{[i]: 0<=i<n}",
+                "{[jj]: 0<=jj<row_len}",
+                ],
             [
                 "<> row_start = a_rowstarts[i]",
                 "<> row_len = a_rowstarts[i+1] - row_start",
@@ -400,7 +403,10 @@ def test_dependent_loop_bounds_3(ctx_factory):
     ctx = ctx_factory()
 
     knl = lp.make_kernel(ctx.devices[0],
-            "[n,row_len] -> {[i,j]: 0<=i<n and 0<=j<row_len}",
+            [
+                "{[i]: 0<=i<n}",
+                "{[jj]: 0<=jj<row_len}",
+                ],
             [
                 "<> row_len = a_row_lengths[i]",
                 "a[i,j] = 1",
-- 
GitLab