From fd10efe4ea012e3e459281a8d083992a657a3f3e Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Thu, 25 Jan 2018 11:41:17 -0600
Subject: [PATCH] Change strides check to avoid function calls.

---
 loopy/target/c/c_execution.py | 14 ++++++++---
 loopy/target/execution.py     | 45 +++++++++++++++++++++++------------
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/loopy/target/c/c_execution.py b/loopy/target/c/c_execution.py
index 4fd248c87..bba3a8d56 100644
--- a/loopy/target/c/c_execution.py
+++ b/loopy/target/c/c_execution.py
@@ -105,13 +105,21 @@ class CExecutionWrapperGenerator(ExecutionWrapperGeneratorBase):
                         kernel_arg.dtype.numpy_dtype),
                     order=order))
 
+        expected_strides = tuple(
+                var("_lpy_expected_strides_%s" % i)
+                for i in range(num_axes))
+
+        gen("(%s,) = %s.strides" % (", ".join(expected_strides), arg.name))
+
         #check strides
         if not skip_arg_checks:
-            gen("assert _lpy_filter_stride(%(name)s.shape, %(strides)s) "
-                    "== _lpy_filter_stride(%(name)s.shape, %(name)s.strides), "
+            strides_check_expr = self.get_strides_check_expr(
+                    sym_shape, sym_strides, expected_strides)
+            gen("assert %(strides_check)s, "
                     "'Strides of loopy created array %(name)s, "
                     "do not match expected.'" %
-                    dict(name=arg.name,
+                    dict(strides_check=strides_check_expr,
+                         name=arg.name,
                          strides=strify(sym_strides)))
             for i in range(num_axes):
                 gen("del _lpy_shape_%d" % i)
diff --git a/loopy/target/execution.py b/loopy/target/execution.py
index facd56a07..18d33461c 100644
--- a/loopy/target/execution.py
+++ b/loopy/target/execution.py
@@ -351,6 +351,13 @@ class ExecutionWrapperGeneratorBase(object):
     def get_arg_pass(self, arg):
         raise NotImplementedError()
 
+    def get_strides_check_expr(self, shape, strides, sym_strides):
+        # Returns an expression suitable for use for checking the strides of an
+        # argument.
+        return " and ".join(
+                "(%s == 1 or %s == %s)" % elem
+                for elem in zip(shape, strides, sym_strides))
+
     # {{{ arg setup
 
     def generate_arg_setup(
@@ -364,11 +371,6 @@ class ExecutionWrapperGeneratorBase(object):
 
         gen("# {{{ set up array arguments")
 
-        gen("")
-        gen("def _lpy_filter_stride(shape, stride):")
-        gen("    return tuple(s for dim, s in zip(shape, stride) if dim > 1)")
-        gen("")
-
         if not options.no_numpy:
             gen("_lpy_encountered_numpy = False")
             gen("_lpy_encountered_dev = False")
@@ -520,21 +522,34 @@ class ExecutionWrapperGeneratorBase(object):
                         itemsize = kernel_arg.dtype.numpy_dtype.itemsize
                         sym_strides = tuple(
                                 itemsize*s_i for s_i in arg.unvec_strides)
-                        gen("if _lpy_filter_stride(%s.shape, %s.strides) != "
-                                    "_lpy_filter_stride(%s.shape, %s):"
-                                    % (
-                                        arg.name, arg.name, arg.name,
-                                        strify(sym_strides)))
+
+                        ndim = len(arg.unvec_shape)
+                        shape = ["_lpy_shape_%d" % i for i in range(ndim)]
+                        strides = ["_lpy_stride_%d" % i for i in range(ndim)]
+
+                        gen("(%s,) = %s.shape" % (", ".join(shape), arg.name))
+                        gen("(%s,) = %s.strides" % (", ".join(strides), arg.name))
+
+                        gen("if not (%s):"
+                                % self.get_strides_check_expr(
+                                    shape, strides,
+                                    (strify(s) for s in sym_strides)))
                         with Indentation(gen):
+                            gen("_lpy_got = tuple(stride "
+                                    "for (dim, stride) in zip(%s.shape, %s.strides) "
+                                    "if dim > 1)"
+                                    % (arg.name, arg.name))
+                            gen("_lpy_expected = tuple(stride "
+                                    "for (dim, stride) in zip(%s.shape, %s) "
+                                    "if dim > 1)"
+                                    % (arg.name, strify_tuple(sym_strides)))
+
                             gen("raise TypeError(\"strides mismatch on "
                                     "argument '%s' "
                                     "(after removing unit length dims, "
                                     "got: %%s, expected: %%s)\" "
-                                    "%% (_lpy_filter_stride(%s.shape, %s.strides), "
-                                    "_lpy_filter_stride(%s.shape, %s)))"
-                                    % (
-                                        arg.name, arg.name, arg.name, arg.name,
-                                        strify(sym_strides)))
+                                    "%% (_lpy_got, _lpy_expected))"
+                                    % arg.name)
 
                     if not arg.allows_offset:
                         gen("if hasattr(%s, 'offset') and %s.offset:" % (
-- 
GitLab