From 5d5cb856e06b101eef1c6f5fc05a91f874320840 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20Kl=C3=B6ckner?= <inform@tiker.net>
Date: Thu, 14 Nov 2019 23:46:58 +0100
Subject: [PATCH] Wait for Array.events in more places (elwise, reduction,
 scan, algorithms)

---
 pyopencl/algorithm.py   | 74 ++++++++++++++++++++++++++++-------------
 pyopencl/elementwise.py | 18 ++++++----
 pyopencl/reduction.py   |  9 +++++
 pyopencl/scan.py        | 13 ++++++++
 4 files changed, 84 insertions(+), 30 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 620e72d0..430199b9 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -1,12 +1,11 @@
-"""Scan primitive."""
+"""Algorithms built on scans."""
 
-from __future__ import division
-from __future__ import absolute_import
-from six.moves import range
-from six.moves import zip
+from __future__ import division, absolute_import
 
-__copyright__ = """Copyright 2011-2012 Andreas Kloeckner \
-                   Copyright 2017 Hao Gao"""
+__copyright__ = """
+Copyright 2011-2012 Andreas Kloeckner
+Copyright 2017 Hao Gao
+"""
 
 __license__ = """
 Permission is hereby granted, free of charge, to any person
@@ -31,6 +30,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 OTHER DEALINGS IN THE SOFTWARE.
 """
 
+from six.moves import range, zip
+
 import numpy as np
 import pyopencl as cl
 import pyopencl.array  # noqa
@@ -40,35 +41,41 @@ from pytools import memoize, memoize_method, Record
 from mako.template import Template
 
 
-# {{{ copy_if
+# {{{ "extra args" handling utility
 
-_copy_if_template = ScanTemplate(
-        arguments="item_t *ary, item_t *out, scan_t *count",
-        input_expr="(%(predicate)s) ? 1 : 0",
-        scan_expr="a+b", neutral="0",
-        output_statement="""
-            if (prev_item != item) out[item-1] = ary[i];
-            if (i+1 == N) *count = item;
-            """,
-        template_processor="printf")
-
-
-def extract_extra_args_types_values(extra_args):
+def _extract_extra_args_types_values(extra_args):  # (name, value) pairs
     from pyopencl.tools import VectorArg, ScalarArg
 
     extra_args_types = []
     extra_args_values = []
+    extra_wait_for = []  # events of Array args; returned to the caller
     for name, val in extra_args:
         if isinstance(val, cl.array.Array):
             extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
             extra_args_values.append(val)
+            extra_wait_for.extend(val.events)
         elif isinstance(val, np.generic):
             extra_args_types.append(ScalarArg(val.dtype, name))
             extra_args_values.append(val)
         else:
-            raise RuntimeError("argument '%d' not understood" % name)
+            raise RuntimeError("argument '%s' not understood" % name)
 
-    return tuple(extra_args_types), extra_args_values
+    return tuple(extra_args_types), extra_args_values, extra_wait_for
+
+# }}}
+
+
+# {{{ copy_if
+
+_copy_if_template = ScanTemplate(
+        arguments="item_t *ary, item_t *out, scan_t *count",
+        input_expr="(%(predicate)s) ? 1 : 0",
+        scan_expr="a+b", neutral="0",
+        output_statement="""
+            if (prev_item != item) out[item-1] = ary[i];
+            if (i+1 == N) *count = item;
+            """,
+        template_processor="printf")
 
 
 def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None):
@@ -94,7 +101,12 @@ def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=Non
     else:
         scan_dtype = np.int32
 
-    extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args)
+    if wait_for is None:
+        wait_for = []
+
+    extra_args_types, extra_args_values, extra_wait_for = \
+        _extract_extra_args_types_values(extra_args)
+    wait_for = wait_for + extra_wait_for
 
     knl = _copy_if_template.build(ary.context,
             type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
@@ -175,7 +187,12 @@ def partition(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=N
     else:
         scan_dtype = np.uint32
 
-    extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args)
+    if wait_for is None:
+        wait_for = []
+
+    extra_args_types, extra_args_values, extra_wait_for = \
+            _extract_extra_args_types_values(extra_args)
+    wait_for = wait_for + extra_wait_for
 
     knl = _partition_template.build(
             ary.context,
@@ -242,7 +259,12 @@ def unique(ary, is_equal_expr="a == b", extra_args=[], preamble="",
     else:
         scan_dtype = np.uint32
 
-    extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args)
+    if wait_for is None:
+        wait_for = []
+
+    extra_args_types, extra_args_values, extra_wait_for = \
+            _extract_extra_args_types_values(extra_args)
+    wait_for = wait_for + extra_wait_for
 
     knl = _unique_template.build(
             ary.context,
@@ -1104,6 +1126,9 @@ class ListOfListsBuilder:
 
         if wait_for is None:
             wait_for = []
+        else:
+            # We'll be modifying it below.
+            wait_for = list(wait_for)
 
         count_kernel = self.get_count_kernel(index_dtype)
         write_kernel = self.get_write_kernel(index_dtype)
@@ -1130,6 +1155,7 @@ class ListOfListsBuilder:
                 data_args.append(arg_val.base_data)
                 if arg_descr.with_offset:
                     data_args.append(arg_val.offset)
+                wait_for.extend(arg_val.events)
             else:
                 data_args.append(arg_val)
 
diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index cbd8d746..c9822cb2 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -1,9 +1,7 @@
 """Elementwise functionality."""
 
-from __future__ import division
-from __future__ import absolute_import
-from six.moves import range
-from six.moves import zip
+from __future__ import division, absolute_import
+from six.moves import range, zip
 
 __copyright__ = "Copyright (C) 2009 Andreas Kloeckner"
 
@@ -250,6 +248,15 @@ class ElementwiseKernel:
         use_range = range_ is not None or slice_ is not None
         kernel, arg_descrs = self.get_kernel(use_range)
 
+        queue = kwargs.pop("queue", None)
+        wait_for = kwargs.pop("wait_for", None)
+
+        if wait_for is None:
+            wait_for = []
+        else:
+            # We'll be modifying it below.
+            wait_for = list(wait_for)
+
         # {{{ assemble arg array
 
         invocation_args = []
@@ -265,13 +272,12 @@ class ElementwiseKernel:
                 invocation_args.append(arg.base_data)
                 if arg_descr.with_offset:
                     invocation_args.append(arg.offset)
+                wait_for.extend(arg.events)
             else:
                 invocation_args.append(arg)
 
         # }}}
 
-        queue = kwargs.pop("queue", None)
-        wait_for = kwargs.pop("wait_for", None)
         if kwargs:
             raise TypeError("unknown keyword arguments: '%s'"
                     % ", ".join(kwargs))
diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index 0b39dd70..7c25f05b 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -304,6 +304,12 @@ class ReductionKernel:
         return_event = kwargs.pop("return_event", False)
         out = kwargs.pop("out", None)
 
+        if wait_for is None:
+            wait_for = []
+        else:
+            # We'll be modifying it below.
+            wait_for = list(wait_for)
+
         range_ = kwargs.pop("range", None)
         slice_ = kwargs.pop("slice", None)
 
@@ -327,6 +333,7 @@ class ReductionKernel:
                     invocation_args.append(arg.base_data)
                     if arg_tp.with_offset:
                         invocation_args.append(arg.offset)
+                    wait_for.extend(arg.events)
                 else:
                     invocation_args.append(arg)
 
@@ -413,6 +420,8 @@ class ReductionKernel:
                     wait_for=wait_for)
             wait_for = [last_evt]
 
+            result.add_event(last_evt)
+
             if group_count == 1:
                 if return_event:
                     return result, last_evt
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 8ec5043d..6e40c06c 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -1468,6 +1468,11 @@ class GenericScanKernel(_GenericScanKernelBase):
         n = kwargs.get("size")
         wait_for = kwargs.get("wait_for")
 
+        if wait_for is None:
+            wait_for = []
+        else:
+            wait_for = list(wait_for)
+
         if len(args) != len(self.parsed_args):
             raise TypeError("expected %d arguments, got %d" %
                     (len(self.parsed_args), len(args)))
@@ -1490,6 +1495,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                 data_args.append(arg_val.base_data)
                 if arg_descr.with_offset:
                     data_args.append(arg_val.offset)
+                wait_for.extend(arg_val.events)
             else:
                 data_args.append(arg_val)
 
@@ -1678,6 +1684,12 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
         n = kwargs.get("size")
         wait_for = kwargs.get("wait_for")
 
+        if wait_for is None:
+            wait_for = []
+        else:
+            # We'll be modifying it below.
+            wait_for = list(wait_for)
+
         if len(args) != len(self.parsed_args):
             raise TypeError("expected %d arguments, got %d" %
                     (len(self.parsed_args), len(args)))
@@ -1700,6 +1712,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
                 data_args.append(arg_val.base_data)
                 if arg_descr.with_offset:
                     data_args.append(arg_val.offset)
+                wait_for.extend(arg_val.events)
             else:
                 data_args.append(arg_val)
 
-- 
GitLab