From ce6107f9aaac4786d8dac53558c161c341358615 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 16 May 2021 13:31:14 -0500 Subject: [PATCH] Invoker: streamline generation of wait_for --- pyopencl/invoker.py | 48 ++++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py index f5647808..568d4bd1 100644 --- a/pyopencl/invoker.py +++ b/pyopencl/invoker.py @@ -75,8 +75,7 @@ def generate_generic_arg_handling_body(num_args): BUF_PACK_TYPECHARS = ["c", "b", "B", "h", "H", "i", "I", "l", "L", "f", "d"] -def generate_specific_arg_handling_body(function_name, - num_cl_args, arg_types, +def generate_specific_arg_handling_body(function_name, num_cl_args, arg_types, *, work_around_arg_count_bug, warn_about_arg_count_bug, in_enqueue, include_debug_code): @@ -104,16 +103,7 @@ def generate_specific_arg_handling_body(function_name, buf_indices_and_args.append(arg_idx) buf_indices_and_args.append(f"pack('{typechar}', {expr_str})") - if in_enqueue and arg_types is not None and \ - any(isinstance(arg_type, VectorArg) for arg_type in arg_types): - # We're about to modify wait_for, make sure it's a copy. - gen(""" - if wait_for is None: - wait_for = [] - else: - wait_for = list(wait_for) - """) - gen("") + wait_for_parts = [] for arg_idx, arg_type in enumerate(arg_types): arg_var = "arg%d" % arg_idx @@ -147,7 +137,7 @@ def generate_specific_arg_handling_body(function_name, cl_arg_idx += 1 if in_enqueue: - gen(f"wait_for.extend({arg_var}.events)") + wait_for_parts .append(f"{arg_var}.events") continue @@ -225,7 +215,10 @@ def generate_specific_arg_handling_body(function_name, "CL-generated number of arguments (%d) do not agree" % (cl_arg_idx, num_cl_args)) - return gen + if in_enqueue: + return gen, wait_for_parts + else: + return gen # }}} @@ -239,7 +232,12 @@ def _generate_enqueue_and_set_args_module(function_name, def gen_arg_setting(in_enqueue): if arg_types is None: - return generate_generic_arg_handling_body(num_passed_args) + result = generate_generic_arg_handling_body(num_passed_args) + if in_enqueue: + return result, [] + else: + return result + else: return generate_specific_arg_handling_body( function_name, num_cl_args, arg_types, @@ -269,13 +267,23 @@ def _generate_enqueue_and_set_args_module(function_name, "wait_for=None"]))) with Indentation(gen): - gen.extend(gen_arg_setting(in_enqueue=True)) + subgen, wait_for_parts = gen_arg_setting(in_enqueue=True) + gen.extend(subgen) + + if wait_for_parts: + wait_for_expr = ( + "[*(() if wait_for is None else wait_for), " + + ", ".join("*"+wfp for wfp in wait_for_parts) + + "]") + else: + wait_for_expr = "wait_for" # Using positional args here because pybind is slow with keyword args - gen(""" - return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size, - global_offset, wait_for, g_times_l, - allow_empty_ndrange) + gen(f""" + return _cl.enqueue_nd_range_kernel(queue, self, + global_size, local_size, global_offset, + {wait_for_expr}, + g_times_l, allow_empty_ndrange) """) # }}} -- GitLab