Skip to content
Snippets Groups Projects
Commit 4ebc1af1 authored by Andreas Klöckner's avatar Andreas Klöckner
Browse files

Make checking in CompiledKernel use assert.

parent 06c0bfc2
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,36 @@ import numpy as np ...@@ -7,6 +7,36 @@ import numpy as np
# {{{ argument checking
def _arg_matches_spec(arg, val, other_args):
import loopy as lp
if isinstance(arg, lp.ArrayArg):
from pymbolic import evaluate
shape = evaluate(arg.shape, other_args)
if arg.dtype != val.dtype:
raise TypeError("dtype mismatch on argument '%s' "
"(got: %s, expected: %s)"
% (arg.name, val.dtype, arg.dtype))
if shape != val.shape:
raise TypeError("shape mismatch on argument '%s' "
"(got: %s, expected: %s)"
% (arg.name, val.shape, shape))
if arg.order == "F" and not val.flags.f_contiguous:
raise TypeError("order mismatch on argument '%s' "
"(expected Fortran-contiguous, but isn't)"
% (arg.name))
if arg.order == "C" and not val.flags.c_contiguous:
print id(val), val.flags
raise TypeError("order mismatch on argument '%s' "
"(expected C-contiguous, but isn't)"
% (arg.name))
return True
# }}}
# {{{ compiled kernel object # {{{ compiled kernel object
class CompiledKernel: class CompiledKernel:
...@@ -101,12 +131,11 @@ class CompiledKernel: ...@@ -101,12 +131,11 @@ class CompiledKernel:
allocator = kwargs.pop("allocator", None) allocator = kwargs.pop("allocator", None)
wait_for = kwargs.pop("wait_for", None) wait_for = kwargs.pop("wait_for", None)
out_host = kwargs.pop("out_host", None) out_host = kwargs.pop("out_host", None)
check = kwargs.pop("check", True)
import loopy as lp import loopy as lp
if check and self.needs_check: if self.needs_check:
lp.check_kernels([self.kernel], kwargs) assert len(list(lp.check_kernels([self.kernel], kwargs))) == 1
self.needs_check = False self.needs_check = False
...@@ -133,27 +162,8 @@ class CompiledKernel: ...@@ -133,27 +162,8 @@ class CompiledKernel:
shape = evaluate(arg.shape, kwargs) shape = evaluate(arg.shape, kwargs)
val = cl_array.empty(queue, shape, arg.dtype, order=arg.order, val = cl_array.empty(queue, shape, arg.dtype, order=arg.order,
allocator=allocator) allocator=allocator)
elif check: else:
if isinstance(arg, lp.ArrayArg): assert _arg_matches_spec(arg, val, kwargs)
from pymbolic import evaluate
shape = evaluate(arg.shape, kwargs)
if arg.dtype != val.dtype:
raise TypeError("dtype mismatch on argument '%s'"
"(got: %s, expected: %s)"
% (arg.name, val.dtype, arg.dtype))
if shape != val.shape:
raise TypeError("shape mismatch on argument '%s'"
"(got: %s, expected: %s)"
% (arg.name, val.shape, shape))
if arg.order == "F" and not val.flags.f_contiguous:
raise TypeError("order mismatch on argument '%s'"
"(expected Fortran-contiguous, but isn't)"
% (arg.name))
if arg.order == "C" and not val.flags.c_contiguous:
raise TypeError("order mismatch on argument '%s'"
"(expected C-contiguous, but isn't)"
% (arg.name))
# automatically transfer host-side arrays # automatically transfer host-side arrays
if isinstance(arg, lp.ArrayArg): if isinstance(arg, lp.ArrayArg):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment