Skip to content
Snippets Groups Projects
Commit b66344e0 authored by Matt Wala's avatar Matt Wala
Browse files

Numpy execution: Enable support for relaxed stride checks (closes #121).

parent ed139ad5
No related branches found
No related tags found
No related merge requests found
...@@ -363,6 +363,10 @@ class ExecutionWrapperGeneratorBase(object): ...@@ -363,6 +363,10 @@ class ExecutionWrapperGeneratorBase(object):
from loopy.types import NumpyType from loopy.types import NumpyType
gen("# {{{ set up array arguments") gen("# {{{ set up array arguments")
gen("")
gen("def _lpy_filter_stride(shape, stride):")
gen(" return tuple(s for dim, s in zip(shape, stride) if dim > 1)")
gen("") gen("")
if not options.no_numpy: if not options.no_numpy:
...@@ -516,13 +520,21 @@ class ExecutionWrapperGeneratorBase(object): ...@@ -516,13 +520,21 @@ class ExecutionWrapperGeneratorBase(object):
itemsize = kernel_arg.dtype.numpy_dtype.itemsize itemsize = kernel_arg.dtype.numpy_dtype.itemsize
sym_strides = tuple( sym_strides = tuple(
itemsize*s_i for s_i in arg.unvec_strides) itemsize*s_i for s_i in arg.unvec_strides)
gen("if %s.strides != %s:" gen("if _lpy_filter_stride(%s.shape, %s.strides) != "
% (arg.name, strify(sym_strides))) "_lpy_filter_stride(%s.shape, %s):"
% (
arg.name, arg.name, arg.name,
strify(sym_strides)))
with Indentation(gen): with Indentation(gen):
gen("raise TypeError(\"strides mismatch on " gen("raise TypeError(\"strides mismatch on "
"argument '%s' (got: %%s, expected: %%s)\" " "argument '%s' "
"%% (%s.strides, %s))" "(after removing unit length dims, "
% (arg.name, arg.name, strify(sym_strides))) "got: %%s, expected: %%s)\" "
"%% (_lpy_filter_stride(%s.shape, %s.strides), "
"_lpy_filter_stride(%s.shape, %s)))"
% (
arg.name, arg.name, arg.name, arg.name,
strify(sym_strides)))
if not arg.allows_offset: if not arg.allows_offset:
gen("if hasattr(%s, 'offset') and %s.offset:" % ( gen("if hasattr(%s, 'offset') and %s.offset:" % (
......
...@@ -2746,6 +2746,26 @@ def test_arg_inference_for_predicates(): ...@@ -2746,6 +2746,26 @@ def test_arg_inference_for_predicates():
assert knl.arg_dict["incr"].shape == (10,) assert knl.arg_dict["incr"].shape == (10,)
def test_relaxed_stride_checks(ctx_factory):
# Check that loopy is compatible with numpy's relaxed stride rules.
ctx = ctx_factory()
knl = lp.make_kernel("{[i,j]: 0 <= i <= n and 0 <= j <= m}",
"""
a[i] = sum(j, A[i,j] * b[j])
""")
with cl.CommandQueue(ctx) as queue:
A = np.zeros((1, 10), order="F")
# Force convert A to C order. numpy will preserve strides in this case.
A = np.array(A, copy=False, order="C")
b = np.zeros(10, dtype=np.float64)
evt, (a,) = knl(queue, A=A, b=b)
assert a == 0
def test_add_prefetch_works_in_lhs_index(): def test_add_prefetch_works_in_lhs_index():
knl = lp.make_kernel( knl = lp.make_kernel(
"{ [n,k,l,k1,l1,k2,l2]: " "{ [n,k,l,k1,l1,k2,l2]: "
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment