From 7129db53f6e93c13a88bb75d6a37c27d6e5010a7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 11 Jul 2013 19:17:58 -0400 Subject: [PATCH] Make modulo in indices less crashy --- loopy/check.py | 3 +++ loopy/kernel/creation.py | 5 ++++- test/test_loopy.py | 23 ++++++++++++++++++++++- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/loopy/check.py b/loopy/check.py index a71bad222..36a50df48 100644 --- a/loopy/check.py +++ b/loopy/check.py @@ -253,6 +253,9 @@ class _AccessCheckMapper(WalkMapper): except isl.Error: # Likely: index was non-linear, nothing we can do. return + except TypeError: + # Likely: index was non-linear, nothing we can do. + return shape_domain = isl.BasicSet.universe(access_range.get_space()) for idim in xrange(len(subscript)): diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index ab2b2eeb9..1f8d0f38c 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -784,7 +784,10 @@ def guess_arg_shape_if_requested(kernel, default_order): except TypeError, e: from loopy.diagnostic import LoopyError raise LoopyError( - "failed to find access range for argument '%s': %s" + "Failed to (automatically, as requested) find " + "shape/strides for argument '%s'. " + "Specifying the shape manually should get rid of this. " + "The following error occurred: %s" % (arg.name, str(e))) if armap.access_range is None: diff --git a/test/test_loopy.py b/test/test_loopy.py index 446de00db..5bccdd5dc 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -1282,10 +1282,31 @@ def test_split_reduction(ctx_factory): "..."]) knl = lp.split_reduction_outward(knl, "j,k") - print knl # FIXME: finish test +def test_modulo_indexing(ctx_factory): + ctx = ctx_factory() + + knl = lp.make_kernel(ctx.devices[0], [ + "{[i,j]: 0<=i<n and 0<=j<5}", + ], + """ + b[i] = sum(j, a[(i+j)%n]) + """, + [ + lp.GlobalArg("a", None, shape="n"), + "..." + ] + ) + + print knl + print lp.CompiledKernel(ctx, knl).get_highlighted_code( + dict( + a=np.float32, + )) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab