def _arg_matches_spec(arg, val, other_args):
    """Check that the run-time value *val* conforms to the declaration *arg*.

    Only arguments that are instances of ``loopy.ArrayArg`` are inspected;
    for any other kind of argument the function accepts *val* unchecked.

    For array arguments, the following are verified, in this order:

    * ``val.dtype`` equals ``arg.dtype``,
    * ``val.shape`` equals ``arg.shape`` evaluated (via
      :func:`pymbolic.evaluate`) with *other_args* as the variable context,
    * ``val.flags`` reports the contiguity demanded by ``arg.order``
      (``"F"`` or ``"C"``; any other order value is not checked).

    :arg arg: the kernel argument declaration.
    :arg val: the value passed for that argument. For array arguments it
        must expose ``dtype``, ``shape`` and ``flags`` attributes.
    :arg other_args: mapping handed to :func:`pymbolic.evaluate` to resolve
        the symbolic shape of *arg*.
    :returns: *True* if no mismatch was found. The function never returns
        *False*; the constant return value merely allows the call to be
        used as the expression of an ``assert`` statement.
    :raises TypeError: on the first dtype, shape or order mismatch.
    """
    import loopy as lp

    if not isinstance(arg, lp.ArrayArg):
        # Nothing to verify for non-array arguments.
        return True

    from pymbolic import evaluate
    shape = evaluate(arg.shape, other_args)

    if arg.dtype != val.dtype:
        raise TypeError("dtype mismatch on argument '%s' "
                "(got: %s, expected: %s)"
                % (arg.name, val.dtype, arg.dtype))

    if shape != val.shape:
        raise TypeError("shape mismatch on argument '%s' "
                "(got: %s, expected: %s)"
                % (arg.name, val.shape, shape))

    if arg.order == "F" and not val.flags.f_contiguous:
        raise TypeError("order mismatch on argument '%s' "
                "(expected Fortran-contiguous, but isn't)"
                % (arg.name,))

    # NOTE: a leftover debugging ``print`` of id(val)/val.flags used to sit
    # in this branch; it was removed so that a failed check only raises.
    if arg.order == "C" and not val.flags.c_contiguous:
        raise TypeError("order mismatch on argument '%s' "
                "(expected C-contiguous, but isn't)"
                % (arg.name,))

    return True
- if shape != val.shape: - raise TypeError("shape mismatch on argument '%s'" - "(got: %s, expected: %s)" - % (arg.name, val.shape, shape)) - if arg.order == "F" and not val.flags.f_contiguous: - raise TypeError("order mismatch on argument '%s'" - "(expected Fortran-contiguous, but isn't)" - % (arg.name)) - if arg.order == "C" and not val.flags.c_contiguous: - raise TypeError("order mismatch on argument '%s'" - "(expected C-contiguous, but isn't)" - % (arg.name)) + else: + assert _arg_matches_spec(arg, val, kwargs) # automatically transfer host-side arrays if isinstance(arg, lp.ArrayArg):