diff --git a/loopy/compiled.py b/loopy/compiled.py index 92997594a3fb0c13e81c02d756e50799b3ad448a..da9d11fdb2300a28365fcd8058c3681821b6ae2d 100644 --- a/loopy/compiled.py +++ b/loopy/compiled.py @@ -12,6 +12,37 @@ import numpy as np class CompiledKernel: def __init__(self, context, kernel, size_args=None, options=[], edit_code=False, codegen_kwargs={}): + import loopy as lp + + # {{{ do scheduling, if not yet done + + needs_check = False + + if kernel.schedule is None: + kernel_count = 0 + + for scheduled_kernel in lp.generate_loop_schedules(kernel): + kernel_count += 1 + + if kernel_count == 1: + # use the first schedule + kernel = scheduled_kernel + + if kernel_count == 2: + from warnings import warn + warn("kernel scheduling was ambiguous--more than one " + "schedule found, ignoring", stacklevel=2) + break + + needs_check = True + + # Whether we need to call check_kernels. Since we don't have parameter + # values now, we'll do that on first invocation. + + self.needs_check = needs_check + + # }}} + self.kernel = kernel from loopy.codegen import generate_code self.code = generate_code(kernel, **codegen_kwargs) @@ -20,7 +51,6 @@ class CompiledKernel: from pytools import invoke_editor self.code = invoke_editor(self.code) - import pyopencl as cl try: self.cl_program = cl.Program(context, self.code) self.cl_kernel = getattr( @@ -47,7 +77,6 @@ class CompiledKernel: self.cl_kernel.set_scalar_arg_dtypes(arg_types) - from pymbolic import compile if size_args is None: self.size_args = kernel.scalar_loop_args else: @@ -58,11 +87,94 @@ class CompiledKernel: if not gsize_expr: gsize_expr = (1,) if not lsize_expr: lsize_expr = (1,) + from pymbolic import compile self.global_size_func = compile( gsize_expr, self.size_args) self.local_size_func = compile( lsize_expr, self.size_args) + def __call__(self, queue, **kwargs): + allocator = kwargs.pop("allocator", None) + wait_for = kwargs.pop("wait_for", None) + out_host = kwargs.pop("out_host", False) + check = kwargs.pop("check", True) + + import loopy as lp + + if check and self.needs_check: + lp.check_kernels([self.kernel], kwargs) + + self.needs_check = False + + domain_parameters = dict((name, kwargs[name]) + for name in self.kernel.scalar_loop_args) + + args = [] + outputs = [] + + for arg in self.kernel.args: + is_written = arg.name in self.kernel.get_written_variables() + + val = kwargs.get(arg.name) + if val is None: + if not is_written: + raise TypeError("must supply input argument '%s'" % arg.name) + + if isinstance(arg, lp.ImageArg): + raise RuntimeError("write-mode image '%s' must " + "be explicitly supplied" % arg.name) + + from pymbolic import evaluate + shape = evaluate(arg.shape, kwargs) + val = cl_array.empty(queue, shape, arg.dtype, order=arg.order, + allocator=allocator) + elif check: + if isinstance(arg, lp.ArrayArg): + from pymbolic import evaluate + shape = evaluate(arg.shape, kwargs) + + if arg.dtype != val.dtype: + raise TypeError("dtype mismatch on argument '%s'" + "(got: %s, expected: %s)" + % (arg.name, val.dtype, arg.dtype)) + if shape != val.shape: + raise TypeError("shape mismatch on argument '%s'" + "(got: %s, expected: %s)" + % (arg.name, val.shape, shape)) + if arg.order == "F" and not val.flags.f_contiguous: + raise TypeError("order mismatch on argument '%s'" + "(expected Fortran-contiguous, but isn't)" + % (arg.name)) + if arg.order == "C" and not val.flags.c_contiguous: + raise TypeError("order mismatch on argument '%s'" + "(expected C-contiguous, but isn't)" + % (arg.name)) + + # automatically transfer host-side arrays + if isinstance(val, np.ndarray) and isinstance(arg, lp.ArrayArg): + # synchronous, so nothing to worry about + val = cl_array.to_device(queue, val, allocator=allocator) + + if is_written: + outputs.append(val) + + if isinstance(arg, lp.ArrayArg): + args.append(val.data) + else: + args.append(val) + + evt = self.cl_kernel(queue, + self.global_size_func(**domain_parameters), + self.local_size_func(**domain_parameters), + *args, + g_times_l=True, wait_for=wait_for) + + if out_host: + outputs = [o.get() for o in outputs] + + return [evt] + outputs + + # }}}