From fd10efe4ea012e3e459281a8d083992a657a3f3e Mon Sep 17 00:00:00 2001 From: Matt Wala <wala1@illinois.edu> Date: Thu, 25 Jan 2018 11:41:17 -0600 Subject: [PATCH] Change strides check to avoid function calls. --- loopy/target/c/c_execution.py | 14 ++++++++--- loopy/target/execution.py | 45 +++++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/loopy/target/c/c_execution.py b/loopy/target/c/c_execution.py index 4fd248c87..bba3a8d56 100644 --- a/loopy/target/c/c_execution.py +++ b/loopy/target/c/c_execution.py @@ -105,13 +105,21 @@ class CExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): kernel_arg.dtype.numpy_dtype), order=order)) + expected_strides = tuple( + var("_lpy_expected_strides_%s" % i) + for i in range(num_axes)) + + gen("(%s,) = %s.strides" % (", ".join(expected_strides), arg.name)) + #check strides if not skip_arg_checks: - gen("assert _lpy_filter_stride(%(name)s.shape, %(strides)s) " - "== _lpy_filter_stride(%(name)s.shape, %(name)s.strides), " + strides_check_expr = self.get_strides_check_expr( + sym_shape, sym_strides, expected_strides) + gen("assert %(strides_check)s, " "'Strides of loopy created array %(name)s, " "do not match expected.'" % - dict(name=arg.name, + dict(strides_check=strides_check_expr, + name=arg.name, strides=strify(sym_strides))) for i in range(num_axes): gen("del _lpy_shape_%d" % i) diff --git a/loopy/target/execution.py b/loopy/target/execution.py index facd56a07..18d33461c 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -351,6 +351,13 @@ class ExecutionWrapperGeneratorBase(object): def get_arg_pass(self, arg): raise NotImplementedError() + def get_strides_check_expr(self, shape, strides, sym_strides): + # Returns an expression suitable for use for checking the strides of an + # argument. + return " and ".join( + "(%s == 1 or %s == %s)" % elem + for elem in zip(shape, strides, sym_strides)) + # {{{ arg setup def generate_arg_setup( @@ -364,11 +371,6 @@ class ExecutionWrapperGeneratorBase(object): gen("# {{{ set up array arguments") - gen("") - gen("def _lpy_filter_stride(shape, stride):") - gen(" return tuple(s for dim, s in zip(shape, stride) if dim > 1)") - gen("") - if not options.no_numpy: gen("_lpy_encountered_numpy = False") gen("_lpy_encountered_dev = False") @@ -520,21 +522,34 @@ class ExecutionWrapperGeneratorBase(object): itemsize = kernel_arg.dtype.numpy_dtype.itemsize sym_strides = tuple( itemsize*s_i for s_i in arg.unvec_strides) - gen("if _lpy_filter_stride(%s.shape, %s.strides) != " - "_lpy_filter_stride(%s.shape, %s):" - % ( - arg.name, arg.name, arg.name, - strify(sym_strides))) + + ndim = len(arg.unvec_shape) + shape = ["_lpy_shape_%d" % i for i in range(ndim)] + strides = ["_lpy_stride_%d" % i for i in range(ndim)] + + gen("(%s,) = %s.shape" % (", ".join(shape), arg.name)) + gen("(%s,) = %s.strides" % (", ".join(strides), arg.name)) + + gen("if not (%s):" + % self.get_strides_check_expr( + shape, strides, + (strify(s) for s in sym_strides))) with Indentation(gen): + gen("_lpy_got = tuple(stride " + "for (dim, stride) in zip(%s.shape, %s.strides) " + "if dim > 1)" + % (arg.name, arg.name)) + gen("_lpy_expected = tuple(stride " + "for (dim, stride) in zip(%s.shape, %s) " + "if dim > 1)" + % (arg.name, strify_tuple(sym_strides))) + gen("raise TypeError(\"strides mismatch on " "argument '%s' " "(after removing unit length dims, " "got: %%s, expected: %%s)\" " - "%% (_lpy_filter_stride(%s.shape, %s.strides), " - "_lpy_filter_stride(%s.shape, %s)))" - % ( - arg.name, arg.name, arg.name, arg.name, - strify(sym_strides))) + "%% (_lpy_got, _lpy_expected))" + % arg.name) if not arg.allows_offset: gen("if hasattr(%s, 'offset') and %s.offset:" % ( -- GitLab