Skip to content
Snippets Groups Projects
Commit fd10efe4 authored by Matt Wala's avatar Matt Wala
Browse files

Change strides check to avoid function calls.

parent 24cc712e
No related branches found
No related tags found
1 merge request!214Numpy args: Enable support for relaxed stride checks (closes #121).
Pipeline #
......@@ -105,13 +105,21 @@ class CExecutionWrapperGenerator(ExecutionWrapperGeneratorBase):
kernel_arg.dtype.numpy_dtype),
order=order))
expected_strides = tuple(
var("_lpy_expected_strides_%s" % i)
for i in range(num_axes))
gen("(%s,) = %s.strides" % (", ".join(expected_strides), arg.name))
#check strides
if not skip_arg_checks:
gen("assert _lpy_filter_stride(%(name)s.shape, %(strides)s) "
"== _lpy_filter_stride(%(name)s.shape, %(name)s.strides), "
strides_check_expr = self.get_strides_check_expr(
sym_shape, sym_strides, expected_strides)
gen("assert %(strides_check)s, "
"'Strides of loopy created array %(name)s, "
"do not match expected.'" %
dict(name=arg.name,
dict(strides_check=strides_check_expr,
name=arg.name,
strides=strify(sym_strides)))
for i in range(num_axes):
gen("del _lpy_shape_%d" % i)
......
......@@ -351,6 +351,13 @@ class ExecutionWrapperGeneratorBase(object):
def get_arg_pass(self, arg):
raise NotImplementedError()
def get_strides_check_expr(self, shape, strides, sym_strides):
# Returns an expression suitable for use for checking the strides of an
# argument.
return " and ".join(
"(%s == 1 or %s == %s)" % elem
for elem in zip(shape, strides, sym_strides))
# {{{ arg setup
def generate_arg_setup(
......@@ -364,11 +371,6 @@ class ExecutionWrapperGeneratorBase(object):
gen("# {{{ set up array arguments")
gen("")
gen("def _lpy_filter_stride(shape, stride):")
gen(" return tuple(s for dim, s in zip(shape, stride) if dim > 1)")
gen("")
if not options.no_numpy:
gen("_lpy_encountered_numpy = False")
gen("_lpy_encountered_dev = False")
......@@ -520,21 +522,34 @@ class ExecutionWrapperGeneratorBase(object):
itemsize = kernel_arg.dtype.numpy_dtype.itemsize
sym_strides = tuple(
itemsize*s_i for s_i in arg.unvec_strides)
gen("if _lpy_filter_stride(%s.shape, %s.strides) != "
"_lpy_filter_stride(%s.shape, %s):"
% (
arg.name, arg.name, arg.name,
strify(sym_strides)))
ndim = len(arg.unvec_shape)
shape = ["_lpy_shape_%d" % i for i in range(ndim)]
strides = ["_lpy_stride_%d" % i for i in range(ndim)]
gen("(%s,) = %s.shape" % (", ".join(shape), arg.name))
gen("(%s,) = %s.strides" % (", ".join(strides), arg.name))
gen("if not (%s):"
% self.get_strides_check_expr(
shape, strides,
(strify(s) for s in sym_strides)))
with Indentation(gen):
gen("_lpy_got = tuple(stride "
"for (dim, stride) in zip(%s.shape, %s.strides) "
"if dim > 1)"
% (arg.name, arg.name))
gen("_lpy_expected = tuple(stride "
"for (dim, stride) in zip(%s.shape, %s) "
"if dim > 1)"
% (arg.name, strify_tuple(sym_strides)))
gen("raise TypeError(\"strides mismatch on "
"argument '%s' "
"(after removing unit length dims, "
"got: %%s, expected: %%s)\" "
"%% (_lpy_filter_stride(%s.shape, %s.strides), "
"_lpy_filter_stride(%s.shape, %s)))"
% (
arg.name, arg.name, arg.name, arg.name,
strify(sym_strides)))
"%% (_lpy_got, _lpy_expected))"
% arg.name)
if not arg.allows_offset:
gen("if hasattr(%s, 'offset') and %s.offset:" % (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment