From ce6107f9aaac4786d8dac53558c161c341358615 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 16 May 2021 13:31:14 -0500
Subject: [PATCH] Invoker: streamline generation of wait_for

---
 pyopencl/invoker.py | 48 ++++++++++++++++++++++++++-------------------
 1 file changed, 28 insertions(+), 20 deletions(-)

diff --git a/pyopencl/invoker.py b/pyopencl/invoker.py
index f5647808..568d4bd1 100644
--- a/pyopencl/invoker.py
+++ b/pyopencl/invoker.py
@@ -75,8 +75,7 @@ def generate_generic_arg_handling_body(num_args):
 BUF_PACK_TYPECHARS = ["c", "b", "B", "h", "H", "i", "I", "l", "L", "f", "d"]
 
 
-def generate_specific_arg_handling_body(function_name,
-        num_cl_args, arg_types,
+def generate_specific_arg_handling_body(function_name, num_cl_args, arg_types, *,
         work_around_arg_count_bug, warn_about_arg_count_bug,
         in_enqueue, include_debug_code):
 
@@ -104,16 +103,7 @@ def generate_specific_arg_handling_body(function_name,
             buf_indices_and_args.append(arg_idx)
             buf_indices_and_args.append(f"pack('{typechar}', {expr_str})")
 
-    if in_enqueue and arg_types is not None and \
-            any(isinstance(arg_type, VectorArg) for arg_type in arg_types):
-        # We're about to modify wait_for, make sure it's a copy.
-        gen("""
-            if wait_for is None:
-                wait_for = []
-            else:
-                wait_for = list(wait_for)
-            """)
-        gen("")
+    wait_for_parts = []
 
     for arg_idx, arg_type in enumerate(arg_types):
         arg_var = "arg%d" % arg_idx
@@ -147,7 +137,7 @@ def generate_specific_arg_handling_body(function_name,
                 cl_arg_idx += 1
 
             if in_enqueue:
-                gen(f"wait_for.extend({arg_var}.events)")
+                wait_for_parts .append(f"{arg_var}.events")
 
             continue
 
@@ -225,7 +215,10 @@ def generate_specific_arg_handling_body(function_name,
             "CL-generated number of arguments (%d) do not agree"
             % (cl_arg_idx, num_cl_args))
 
-    return gen
+    if in_enqueue:
+        return gen, wait_for_parts
+    else:
+        return gen
 
 # }}}
 
@@ -239,7 +232,12 @@ def _generate_enqueue_and_set_args_module(function_name,
 
     def gen_arg_setting(in_enqueue):
         if arg_types is None:
-            return generate_generic_arg_handling_body(num_passed_args)
+            result = generate_generic_arg_handling_body(num_passed_args)
+            if in_enqueue:
+                return result, []
+            else:
+                return result
+
         else:
             return generate_specific_arg_handling_body(
                     function_name, num_cl_args, arg_types,
@@ -269,13 +267,23 @@ def _generate_enqueue_and_set_args_module(function_name,
                         "wait_for=None"])))
 
     with Indentation(gen):
-        gen.extend(gen_arg_setting(in_enqueue=True))
+        subgen, wait_for_parts = gen_arg_setting(in_enqueue=True)
+        gen.extend(subgen)
+
+        if wait_for_parts:
+            wait_for_expr = (
+                    "[*(() if wait_for is None else wait_for), "
+                    + ", ".join("*"+wfp for wfp in wait_for_parts)
+                    + "]")
+        else:
+            wait_for_expr = "wait_for"
 
         # Using positional args here because pybind is slow with keyword args
-        gen("""
-            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
-                    global_offset, wait_for, g_times_l,
-                    allow_empty_ndrange)
+        gen(f"""
+            return _cl.enqueue_nd_range_kernel(queue, self,
+                    global_size, local_size, global_offset,
+                    {wait_for_expr},
+                    g_times_l, allow_empty_ndrange)
             """)
 
     # }}}
-- 
GitLab