From 7129db53f6e93c13a88bb75d6a37c27d6e5010a7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 11 Jul 2013 19:17:58 -0400
Subject: [PATCH] Make modulo in indices less crashy

---
 loopy/check.py           |  3 +++
 loopy/kernel/creation.py |  5 ++++-
 test/test_loopy.py       | 23 ++++++++++++++++++++++-
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/loopy/check.py b/loopy/check.py
index a71bad222..36a50df48 100644
--- a/loopy/check.py
+++ b/loopy/check.py
@@ -253,6 +253,9 @@ class _AccessCheckMapper(WalkMapper):
             except isl.Error:
                 # Likely: index was non-linear, nothing we can do.
                 return
+            except TypeError:
+                # Likely: index was non-linear, nothing we can do.
+                return
 
             shape_domain = isl.BasicSet.universe(access_range.get_space())
             for idim in xrange(len(subscript)):
diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py
index ab2b2eeb9..1f8d0f38c 100644
--- a/loopy/kernel/creation.py
+++ b/loopy/kernel/creation.py
@@ -784,7 +784,10 @@ def guess_arg_shape_if_requested(kernel, default_order):
             except TypeError, e:
                 from loopy.diagnostic import LoopyError
                 raise LoopyError(
-                        "failed to find access range for argument '%s': %s"
+                        "Failed to (automatically, as requested) find "
+                        "shape/strides for argument '%s'. "
+                        "Specifying the shape manually should get rid of this. "
+                        "The following error occurred: %s"
                         % (arg.name, str(e)))
 
             if armap.access_range is None:
diff --git a/test/test_loopy.py b/test/test_loopy.py
index 446de00db..5bccdd5dc 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -1282,10 +1282,31 @@ def test_split_reduction(ctx_factory):
                 "..."])
 
     knl = lp.split_reduction_outward(knl, "j,k")
-    print knl
     # FIXME: finish test
 
 
+def test_modulo_indexing(ctx_factory):
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel(ctx.devices[0], [
+            "{[i,j]: 0<=i<n and 0<=j<5}",
+            ],
+            """
+                b[i] = sum(j, a[(i+j)%n])
+                """,
+            [
+                lp.GlobalArg("a", None, shape="n"),
+                "..."
+                ]
+            )
+
+    print knl
+    print lp.CompiledKernel(ctx, knl).get_highlighted_code(
+            dict(
+                a=np.float32,
+                ))
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
-- 
GitLab