From b66d0823938a2d3660ad15a0b0a9c4cca9e970cb Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 24 May 2019 17:56:03 -0500 Subject: [PATCH] Fix, test stride mismatch check --- loopy/target/execution.py | 2 +- test/test_loopy.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/loopy/target/execution.py b/loopy/target/execution.py index 3cdf20577..c8f0d4090 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -531,7 +531,7 @@ class ExecutionWrapperGeneratorBase(object): gen("(%s,) = %s.shape" % (", ".join(shape), arg.name)) gen("(%s,) = %s.strides" % (", ".join(strides), arg.name)) - gen("if not %s:" + gen("if not (%s):" % self.get_strides_check_expr( shape, strides, (strify(s) for s in sym_strides))) diff --git a/test/test_loopy.py b/test/test_loopy.py index 80af89f3b..89b4f5e63 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2970,6 +2970,22 @@ def test_temp_var_type_deprecated_usage(): temp_var_types=(np.dtype(np.int32),)) +def test_shape_mismatch_check(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + prg = lp.make_kernel( + "{[i,j]: 0 <= i < n and 0 <= j < m}", + "c[i] = sum(j, a[i,j]*b[j])", + default_order="F") + + a = np.random.rand(10, 10).astype(np.float32) + b = np.random.rand(10).astype(np.float32) + + with pytest.raises(TypeError, match="strides mismatch"): + prg(queue, a=a, b=b) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab