From b66344e0cec02cbe3eb045a798cbfc88a2460467 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Wed, 24 Jan 2018 18:29:35 -0600
Subject: [PATCH] Numpy execution: Enable support for relaxed stride checks
 (closes #121).

---
 loopy/target/execution.py | 22 +++++++++++++++++-----
 test/test_loopy.py        | 20 ++++++++++++++++++++
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/loopy/target/execution.py b/loopy/target/execution.py
index 2aa76e099..facd56a07 100644
--- a/loopy/target/execution.py
+++ b/loopy/target/execution.py
@@ -363,6 +363,10 @@ class ExecutionWrapperGeneratorBase(object):
         from loopy.types import NumpyType
 
         gen("# {{{ set up array arguments")
+
+        gen("")
+        gen("def _lpy_filter_stride(shape, stride):")
+        gen("    return tuple(s for dim, s in zip(shape, stride) if dim > 1)")
         gen("")
 
         if not options.no_numpy:
@@ -516,13 +520,21 @@ class ExecutionWrapperGeneratorBase(object):
                         itemsize = kernel_arg.dtype.numpy_dtype.itemsize
                         sym_strides = tuple(
                                 itemsize*s_i for s_i in arg.unvec_strides)
-                        gen("if %s.strides != %s:"
-                                % (arg.name, strify(sym_strides)))
+                        gen("if _lpy_filter_stride(%s.shape, %s.strides) != "
+                                    "_lpy_filter_stride(%s.shape, %s):"
+                                    % (
+                                        arg.name, arg.name, arg.name,
+                                        strify(sym_strides)))
                         with Indentation(gen):
                             gen("raise TypeError(\"strides mismatch on "
-                                    "argument '%s' (got: %%s, expected: %%s)\" "
-                                    "%% (%s.strides, %s))"
-                                    % (arg.name, arg.name, strify(sym_strides)))
+                                    "argument '%s' "
+                                    "(after removing unit length dims, "
+                                    "got: %%s, expected: %%s)\" "
+                                    "%% (_lpy_filter_stride(%s.shape, %s.strides), "
+                                    "_lpy_filter_stride(%s.shape, %s)))"
+                                    % (
+                                        arg.name, arg.name, arg.name, arg.name,
+                                        strify(sym_strides)))
 
                     if not arg.allows_offset:
                         gen("if hasattr(%s, 'offset') and %s.offset:" % (
diff --git a/test/test_loopy.py b/test/test_loopy.py
index e624ed346..375b59dcb 100644
--- a/test/test_loopy.py
+++ b/test/test_loopy.py
@@ -2746,6 +2746,26 @@ def test_arg_inference_for_predicates():
     assert knl.arg_dict["incr"].shape == (10,)
 
 
+def test_relaxed_stride_checks(ctx_factory):
+    # Check that loopy accepts arrays valid under numpy's relaxed stride checking.
+    ctx = ctx_factory()
+
+    knl = lp.make_kernel("{[i,j]: 0 <= i <= n and 0 <= j <= m}",
+             """
+             a[i] = sum(j, A[i,j] * b[j])
+             """)
+
+    with cl.CommandQueue(ctx) as queue:
+        A = np.zeros((1, 10), order="F")
+        # Request C order. A has a unit-length axis, so numpy keeps its F strides.
+        A = np.array(A, copy=False, order="C")
+        b = np.zeros(10, dtype=np.float64)
+
+        evt, (a,) = knl(queue, A=A, b=b)
+
+        assert a == 0
+
+
 def test_add_prefetch_works_in_lhs_index():
     knl = lp.make_kernel(
             "{ [n,k,l,k1,l1,k2,l2]: "
-- 
GitLab