From 1c7764f032da97a8b70e23bc7c37a9b0341c0d83 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 30 Apr 2013 23:08:00 -0400
Subject: [PATCH] Make bounds checking not fail on nonlinear indices.

---
 loopy/check.py     | 42 ++++++++++++++++++++++++------------------
 setup.py           |  2 +-
 test/test_loopy.py | 23 +++++++++++++++++++++++
 3 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/loopy/check.py b/loopy/check.py
index f6bf3d88b..64d21b425 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -226,6 +226,8 @@ class _AccessCheckMapper(WalkMapper):
         self.insn_id = insn_id
 
     def map_subscript(self, expr):
+        WalkMapper.map_subscript(self, expr)
+
         from pymbolic.primitives import Variable
         assert isinstance(expr.aggregate, Variable)
 
@@ -247,32 +249,36 @@ class _AccessCheckMapper(WalkMapper):
             from loopy.symbolic import get_dependencies, get_access_range
 
             available_vars = set(self.domain.get_var_dict())
-            if (get_dependencies(subscript) <= available_vars
+            if not (get_dependencies(subscript) <= available_vars
                     and get_dependencies(shape) <= available_vars):
+                return
 
-                if len(subscript) != len(shape):
-                    raise RuntimeError("subscript to '%s' in '%s' has the wrong "
-                            "number of indices (got: %d, expected: %d)" % (
-                                expr.aggregate.name, expr,
-                                len(subscript), len(shape)))
+            if len(subscript) != len(shape):
+                raise RuntimeError("subscript to '%s' in '%s' has the wrong "
+                        "number of indices (got: %d, expected: %d)" % (
+                            expr.aggregate.name, expr,
+                            len(subscript), len(shape)))
 
+            try:
                 access_range = get_access_range(self.domain, subscript)
+            except isl.Error:
+                # Likely: index was non-linear, nothing we can do.
+                return
 
-                shape_domain = isl.BasicSet.universe(access_range.get_space())
-                for idim in xrange(len(subscript)):
-                    from loopy.isl_helpers import make_slab
-                    slab = make_slab(
-                            shape_domain.get_space(), (dim_type.in_, idim),
-                            0, shape[idim])
+            shape_domain = isl.BasicSet.universe(access_range.get_space())
+            for idim in xrange(len(subscript)):
+                from loopy.isl_helpers import make_slab
+                slab = make_slab(
+                        shape_domain.get_space(), (dim_type.in_, idim),
+                        0, shape[idim])
 
-                    shape_domain = shape_domain.intersect(slab)
+                shape_domain = shape_domain.intersect(slab)
 
-                if not access_range.is_subset(shape_domain):
-                    raise RuntimeError("'%s' in instruction '%s' "
-                            "accesses out-of-bounds array element"
-                            % (expr, self.insn_id))
+            if not access_range.is_subset(shape_domain):
+                raise RuntimeError("'%s' in instruction '%s' "
+                        "accesses out-of-bounds array element"
+                        % (expr, self.insn_id))
 
-        WalkMapper.map_subscript(self, expr)
 
 def check_bounds(kernel):
     temp_var_names = set(kernel.temporary_variables)
diff --git a/setup.py b/setup.py
index 19c0f32e5..d14749a5f 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@ setup(name="loopy",
           "pyopencl>=2013.1",
           "pymbolic>=2013.1",
           "cgen",
-          "islpy>=2013.1"
+          "islpy>=2013.2"
           ],
 
       author="Andreas Kloeckner",
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 8c034d6f5..b26340e39 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1052,6 +1052,29 @@ def test_arg_shape_guessing(ctx_factory):
     print knl
     print lp.CompiledKernel(ctx, knl).get_highlighted_code()
 
+
+
+
+def test_nonlinear_index(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0], [
+            "{[i,j]: 0<=i,j<n }",
+            ],
+            """
+                a[i*i] = 17
+                """,
+            [
+                lp.GlobalArg("a", shape="n"),
+                lp.ValueArg("n"),
+                ],
+            assumptions="n>=1")
+
+    print knl
+    print lp.CompiledKernel(ctx, knl).get_highlighted_code()
+
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1:
-- 
GitLab