diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py index 3feb76858d85d17839d2babb9a2ec9981af840df..e452aa953a0492b4b4967198f6d5b6c48f23a220 100644 --- a/pyopencl/invoker.py +++ b/pyopencl/invoker.py @@ -58,13 +58,14 @@ def generate_generic_arg_handling_body(num_args): if num_args == 0: gen("pass") + else: + gen_arg_indices = list(range(num_args)) + gen_args = [f"arg{i}" for i in gen_arg_indices] - for i in range(num_args): - gen(f"# process argument {i}") - gen("") - gen(f"current_arg = {i}") - gen(f"self.set_arg({i}, arg{i})") - gen("") + gen(f"self._set_arg_multi(" + f"({', '.join(str(i) for i in gen_arg_indices)},), " + f"({', '.join(gen_args)},)" + ")") return gen @@ -75,8 +76,7 @@ def generate_generic_arg_handling_body(num_args): def generate_specific_arg_handling_body(function_name, num_cl_args, scalar_arg_dtypes, - work_around_arg_count_bug, warn_about_arg_count_bug, - include_debug_helpers): + work_around_arg_count_bug, warn_about_arg_count_bug): assert work_around_arg_count_bug is not None assert warn_about_arg_count_bug is not None @@ -89,15 +89,17 @@ def generate_specific_arg_handling_body(function_name, if not scalar_arg_dtypes: gen("pass") + gen_arg_indices = [] + gen_args = [] + buf_arg_indices = [] + buf_args = [] + for arg_idx, arg_dtype in enumerate(scalar_arg_dtypes): - gen(f"# process argument {arg_idx}") - gen("") - if include_debug_helpers: - gen(f"current_arg = {arg_idx}") arg_var = "arg%d" % arg_idx if arg_dtype is None: - gen(f"self.set_arg({cl_arg_idx}, {arg_var})") + gen_arg_indices.append(cl_arg_idx) + gen_args.append(arg_var) cl_arg_idx += 1 gen("") continue @@ -105,7 +107,8 @@ def generate_specific_arg_handling_body(function_name, arg_dtype = np.dtype(arg_dtype) if arg_dtype.char == "V": - gen(f"self.set_arg({cl_arg_idx}, {arg_var})") + gen_arg_indices.append(cl_arg_idx) + gen_args.append(arg_var) cl_arg_idx += 1 elif arg_dtype.kind == "c": @@ -126,17 +129,11 @@ def generate_specific_arg_handling_body(function_name, if (work_around_arg_count_bug == "pocl" and arg_dtype == np.complex128 and fp_arg_count + 2 <= 8): - gen( - "buf = pack('{arg_char}', {arg_var}.real)" - .format(arg_char=arg_char, arg_var=arg_var)) - gen(f"self._set_arg_buf({cl_arg_idx}, buf)") + buf_arg_indices.append(cl_arg_idx) + buf_args.append(f"pack('{arg_char}', {arg_var}.real)") cl_arg_idx += 1 - if include_debug_helpers: - gen("current_arg = current_arg + 1000") - gen( - "buf = pack('{arg_char}', {arg_var}.imag)" - .format(arg_char=arg_char, arg_var=arg_var)) - gen(f"self._set_arg_buf({cl_arg_idx}, buf)") + buf_arg_indices.append(cl_arg_idx) + buf_args.append(f"pack('{arg_char}', {arg_var}.imag)") cl_arg_idx += 1 elif (work_around_arg_count_bug == "apple" @@ -148,11 +145,9 @@ def generate_specific_arg_handling_body(function_name, "Cannot pass complex numbers to kernels.") else: - gen( - "buf = pack('{arg_char}{arg_char}', " - "{arg_var}.real, {arg_var}.imag)" - .format(arg_char=arg_char, arg_var=arg_var)) - gen(f"self._set_arg_buf({cl_arg_idx}, buf)") + buf_arg_indices.append(cl_arg_idx) + buf_args.append( + f"pack('{arg_char}{arg_char}', {arg_var}.real, {arg_var}.imag)") cl_arg_idx += 1 fp_arg_count += 2 @@ -163,16 +158,23 @@ def generate_specific_arg_handling_body(function_name, arg_char = arg_dtype.char arg_char = _type_char_map.get(arg_char, arg_char) - gen( - "buf = pack('{arg_char}', {arg_var})" - .format( - arg_char=arg_char, - arg_var=arg_var)) - gen(f"self._set_arg_buf({cl_arg_idx}, buf)") + buf_arg_indices.append(cl_arg_idx) + buf_args.append(f"pack('{arg_char}', {arg_var})") cl_arg_idx += 1 gen("") + for arg_kind, indices, args in [ + ("", gen_arg_indices, gen_args), + ("_buf", buf_arg_indices, buf_args) + ]: + assert len(indices) == len(args) + if indices: + gen(f"self._set_arg{arg_kind}_multi(" + f"({', '.join(str(i) for i in indices)},), " + f"({', '.join(args)},)" + ")") + if cl_arg_idx != num_cl_args: raise TypeError( "length of argument list (%d) and " @@ -184,50 +186,10 @@ def generate_specific_arg_handling_body(function_name, # }}} -# {{{ error handler - -def wrap_in_error_handler(body, arg_names): - err_gen = PythonCodeGenerator() - - def gen_error_handler(): - err_gen(""" - if current_arg is not None: - args = [{args}] - advice = "" - from pyopencl.array import Array - if isinstance(args[current_arg], Array): - advice = " (perhaps you meant to pass 'array.data' " \ - "instead of the array itself?)" - - raise _cl.LogicError( - "when processing argument #%d (1-based): %s%s" - % (current_arg+1, str(e), advice)) - else: - raise - """ - .format(args=", ".join(arg_names))) - err_gen("") - - err_gen("try:") - with Indentation(err_gen): - err_gen.extend(body) - err_gen("except TypeError as e:") - with Indentation(err_gen): - gen_error_handler() - err_gen("except _cl.LogicError as e:") - with Indentation(err_gen): - gen_error_handler() - - return err_gen - -# }}} - - def _generate_enqueue_and_set_args_module(function_name, num_passed_args, num_cl_args, scalar_arg_dtypes, - work_around_arg_count_bug, warn_about_arg_count_bug, - include_debug_helpers): + work_around_arg_count_bug, warn_about_arg_count_bug): arg_names = ["arg%d" % i for i in range(num_passed_args)] @@ -237,11 +199,7 @@ def _generate_enqueue_and_set_args_module(function_name, body = generate_specific_arg_handling_body( function_name, num_cl_args, scalar_arg_dtypes, warn_about_arg_count_bug=warn_about_arg_count_bug, - work_around_arg_count_bug=work_around_arg_count_bug, - include_debug_helpers=include_debug_helpers) - - if include_debug_helpers: - body = wrap_in_error_handler(body, arg_names) + work_around_arg_count_bug=work_around_arg_count_bug) gen = PythonCodeGenerator() @@ -291,7 +249,7 @@ def _generate_enqueue_and_set_args_module(function_name, invoker_cache = WriteOncePersistentDict( - "pyopencl-invoker-cache-v21", + "pyopencl-invoker-cache-v29", key_builder=_NumpyTypesKeyBuilder()) @@ -302,8 +260,7 @@ def generate_enqueue_and_set_args(function_name, cache_key = (function_name, num_passed_args, num_cl_args, scalar_arg_dtypes, - work_around_arg_count_bug, warn_about_arg_count_bug, - not sys.flags.optimize) + work_around_arg_count_bug, warn_about_arg_count_bug) from_cache = False diff --git a/src/wrap_cl.hpp b/src/wrap_cl.hpp index 7a6f056f6c9591354f180260c6c6728f33a029d8..3210eb0991ef86a21be9d122d066551920be6bed 100644 --- a/src/wrap_cl.hpp +++ b/src/wrap_cl.hpp @@ -86,6 +86,7 @@ #endif +#include <functional> #include <thread> #include <mutex> #include <condition_variable> @@ -4452,6 +4453,59 @@ namespace pyopencl set_arg_buf(arg_index, arg); } + static + void set_arg_multi( + std::function<void(cl_uint, py::handle)> set_arg_func, + py::tuple indices, + py::tuple args) + { + // This is an internal interface used by generated invokers. + // We can save a tiny bit of time by not checking their work. + /* + if (indices.size() != args.size()) + throw error("Kernel.set_arg_multi", CL_INVALID_VALUE, + "indices and args arguments do not have the same length"); + */ + + cl_uint arg_index; + py::handle arg_value; + + auto indices_it = indices.begin(), args_it = args.begin(), + indices_end = indices.end(); + try + { + while (indices_it != indices_end) + { + arg_index = py::cast<cl_uint>(*indices_it++); + arg_value = *args_it++; + set_arg_func(arg_index, arg_value); + } + } + catch (error &err) + { + std::string msg( + std::string("when processing arg#") + std::to_string(arg_index+1) + + std::string(" (1-based): ") + std::string(err.what())); + + auto mod_cl_ary(py::module::import("pyopencl.array")); + auto cls_array(mod_cl_ary.attr("Array")); + if (arg_value.ptr() && py::isinstance(arg_value, cls_array)) + msg.append( + " (perhaps you meant to pass 'array.data' instead of the array itself?)"); + + throw error(err.routine().c_str(), err.code(), msg.c_str()); + } + catch (std::exception &err) + { + std::string msg( + std::string("when processing arg#") + std::to_string(arg_index+1) + + std::string(" (1-based): ") + std::string(err.what())); + + throw std::runtime_error(msg.c_str()); + } + } + + py::object get_info(cl_kernel_info param_name) const { switch (param_name) diff --git a/src/wrap_cl_part_2.cpp b/src/wrap_cl_part_2.cpp index e68c785128dda3da644a9dbd76c5feb06d7dd591..3dc46c60e9f05b9541263654c0837b60aece4d8d 100644 --- a/src/wrap_cl_part_2.cpp +++ b/src/wrap_cl_part_2.cpp @@ -470,6 +470,20 @@ void pyopencl_expose_part_2(py::module &m) #if PYOPENCL_CL_VERSION >= 0x2000 .def("_set_arg_svm", &cls::set_arg_svm) #endif + .def("_set_arg_multi", + [](cls &knl, py::tuple indices, py::tuple args) + { + cls::set_arg_multi( + [&](cl_uint i, py::handle arg) { knl.set_arg(i, arg); }, + indices, args); + }) + .def("_set_arg_buf_multi", + [](cls &knl, py::tuple indices, py::tuple args) + { + cls::set_arg_multi( + [&](cl_uint i, py::handle arg) { knl.set_arg_buf(i, arg); }, + indices, args); + }) .DEF_SIMPLE_METHOD(set_arg) #if PYOPENCL_CL_VERSION >= 0x1020 .DEF_SIMPLE_METHOD(get_arg_info)