diff --git a/loopy/target/c/c_execution.py b/loopy/target/c/c_execution.py index c136a9f36f8dd7b797aa9b6875a41e3ea185c0ca..d8b76d32afa64d308648420904f4f4bf8e2e2316 100644 --- a/loopy/target/c/c_execution.py +++ b/loopy/target/c/c_execution.py @@ -105,12 +105,23 @@ class CExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): kernel_arg.dtype.numpy_dtype), order=order)) + expected_strides = tuple( + var("_lpy_expected_strides_%s" % i) + for i in range(num_axes)) + + gen("%s = %s.strides" % (strify(expected_strides), arg.name)) + #check strides if not skip_arg_checks: - gen("assert %(strides)s == %(name)s.strides, " + strides_check_expr = self.get_strides_check_expr( + (strify(s) for s in sym_shape), + (strify(s) for s in sym_strides), + (strify(s) for s in expected_strides)) + gen("assert %(strides_check)s, " "'Strides of loopy created array %(name)s, " "do not match expected.'" % - dict(name=arg.name, + dict(strides_check=strides_check_expr, + name=arg.name, strides=strify(sym_strides))) for i in range(num_axes): gen("del _lpy_shape_%d" % i) diff --git a/loopy/target/execution.py b/loopy/target/execution.py index 2aa76e099d8e50a2949c616736b30f725fb10bb4..3a3ea0a70fe9a9229aa3499ad0bdbfeb87f751ed 100644 --- a/loopy/target/execution.py +++ b/loopy/target/execution.py @@ -351,6 +351,13 @@ class ExecutionWrapperGeneratorBase(object): def get_arg_pass(self, arg): raise NotImplementedError() + def get_strides_check_expr(self, shape, strides, sym_strides): + # Returns an expression suitable for use for checking the strides of an + # argument. Arguments should be sequences of strings. + return " and ".join( + "(%s == 1 or %s == %s)" % elem + for elem in zip(shape, strides, sym_strides)) + # {{{ arg setup def generate_arg_setup( @@ -516,13 +523,34 @@ class ExecutionWrapperGeneratorBase(object): itemsize = kernel_arg.dtype.numpy_dtype.itemsize sym_strides = tuple( itemsize*s_i for s_i in arg.unvec_strides) - gen("if %s.strides != %s:" - % (arg.name, strify(sym_strides))) + + ndim = len(arg.unvec_shape) + shape = ["_lpy_shape_%d" % i for i in range(ndim)] + strides = ["_lpy_stride_%d" % i for i in range(ndim)] + + gen("(%s,) = %s.shape" % (", ".join(shape), arg.name)) + gen("(%s,) = %s.strides" % (", ".join(strides), arg.name)) + + gen("if not %s:" + % self.get_strides_check_expr( + shape, strides, + (strify(s) for s in sym_strides))) with Indentation(gen): + gen("_lpy_got = tuple(stride " + "for (dim, stride) in zip(%s.shape, %s.strides) " + "if dim > 1)" + % (arg.name, arg.name)) + gen("_lpy_expected = tuple(stride " + "for (dim, stride) in zip(%s.shape, %s) " + "if dim > 1)" + % (arg.name, strify_tuple(sym_strides))) + gen("raise TypeError(\"strides mismatch on " - "argument '%s' (got: %%s, expected: %%s)\" " - "%% (%s.strides, %s))" - % (arg.name, arg.name, strify(sym_strides))) + "argument '%s' " + "(after removing unit length dims, " + "got: %%s, expected: %%s)\" " + "%% (_lpy_got, _lpy_expected))" + % arg.name) if not arg.allows_offset: gen("if hasattr(%s, 'offset') and %s.offset:" % ( diff --git a/test/test_loopy.py b/test/test_loopy.py index e624ed346cd696bf18a116e9373f8e765dafdc9a..db18b2a218b653019ac04413d73492766eb850fc 100644 --- a/test/test_loopy.py +++ b/test/test_loopy.py @@ -2746,6 +2746,24 @@ def test_arg_inference_for_predicates(): assert knl.arg_dict["incr"].shape == (10,) +def test_relaxed_stride_checks(ctx_factory): + # Check that loopy is compatible with numpy's relaxed stride rules. + ctx = ctx_factory() + + knl = lp.make_kernel("{[i,j]: 0 <= i <= n and 0 <= j <= m}", + """ + a[i] = sum(j, A[i,j] * b[j]) + """) + + with cl.CommandQueue(ctx) as queue: + mat = np.zeros((1, 10), order="F") + b = np.zeros(10) + + evt, (a,) = knl(queue, A=mat, b=b) + + assert a == 0 + + def test_add_prefetch_works_in_lhs_index(): knl = lp.make_kernel( "{ [n,k,l,k1,l1,k2,l2]: "