diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 620e72d0bc2a3f8445ea274a800609bae66da881..430199b98b62b0e14da0f34441384b9d219054e6 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -1,12 +1,11 @@ -"""Scan primitive.""" +"""Algorithms built on scans.""" -from __future__ import division -from __future__ import absolute_import -from six.moves import range -from six.moves import zip +from __future__ import division, absolute_import -__copyright__ = """Copyright 2011-2012 Andreas Kloeckner \ - Copyright 2017 Hao Gao""" +__copyright__ = """ +Copyright 2011-2012 Andreas Kloeckner +Copyright 2017 Hao Gao +""" __license__ = """ Permission is hereby granted, free of charge, to any person @@ -31,6 +30,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from six.moves import range, zip + import numpy as np import pyopencl as cl import pyopencl.array # noqa @@ -40,35 +41,41 @@ from pytools import memoize, memoize_method, Record from mako.template import Template -# {{{ copy_if +# {{{ "extra args" handling utility -_copy_if_template = ScanTemplate( - arguments="item_t *ary, item_t *out, scan_t *count", - input_expr="(%(predicate)s) ? 1 : 0", - scan_expr="a+b", neutral="0", - output_statement=""" - if (prev_item != item) out[item-1] = ary[i]; - if (i+1 == N) *count = item; - """, - template_processor="printf") - - -def extract_extra_args_types_values(extra_args): +def _extract_extra_args_types_values(extra_args): from pyopencl.tools import VectorArg, ScalarArg extra_args_types = [] extra_args_values = [] + extra_wait_for = [] for name, val in extra_args: if isinstance(val, cl.array.Array): extra_args_types.append(VectorArg(val.dtype, name, with_offset=False)) extra_args_values.append(val) + extra_wait_for.extend(val.events) elif isinstance(val, np.generic): extra_args_types.append(ScalarArg(val.dtype, name)) extra_args_values.append(val) else: raise RuntimeError("argument '%d' not understood" % name) - return tuple(extra_args_types), extra_args_values + return tuple(extra_args_types), extra_args_values, extra_wait_for + +# }}} + + +# {{{ copy_if + +_copy_if_template = ScanTemplate( + arguments="item_t *ary, item_t *out, scan_t *count", + input_expr="(%(predicate)s) ? 1 : 0", + scan_expr="a+b", neutral="0", + output_statement=""" + if (prev_item != item) out[item-1] = ary[i]; + if (i+1 == N) *count = item; + """, + template_processor="printf") def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None): @@ -94,7 +101,12 @@ def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=Non else: scan_dtype = np.int32 - extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) + if wait_for is None: + wait_for = [] + + extra_args_types, extra_args_values, extra_wait_for = \ + _extract_extra_args_types_values(extra_args) + wait_for = wait_for + extra_wait_for knl = _copy_if_template.build(ary.context, type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)), @@ -175,7 +187,12 @@ def partition(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=N else: scan_dtype = np.uint32 - extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) + if wait_for is None: + wait_for = [] + + extra_args_types, extra_args_values, extra_wait_for = \ + _extract_extra_args_types_values(extra_args) + wait_for = wait_for + extra_wait_for knl = _partition_template.build( ary.context, @@ -242,7 +259,12 @@ def unique(ary, is_equal_expr="a == b", extra_args=[], preamble="", else: scan_dtype = np.uint32 - extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) + if wait_for is None: + wait_for = [] + + extra_args_types, extra_args_values, extra_wait_for = \ + _extract_extra_args_types_values(extra_args) + wait_for = wait_for + extra_wait_for knl = _unique_template.build( ary.context, @@ -1104,6 +1126,9 @@ class ListOfListsBuilder: if wait_for is None: wait_for = [] + else: + # We'll be modifying it below. + wait_for = list(wait_for) count_kernel = self.get_count_kernel(index_dtype) write_kernel = self.get_write_kernel(index_dtype) @@ -1130,6 +1155,7 @@ class ListOfListsBuilder: data_args.append(arg_val.base_data) if arg_descr.with_offset: data_args.append(arg_val.offset) + wait_for.extend(arg_val.events) else: data_args.append(arg_val) diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index cbd8d74666d3b7eca797672ca4f80224dd3f150c..c9822cb257e777ef852ffbbed3f47331d17a46bc 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -1,9 +1,7 @@ """Elementwise functionality.""" -from __future__ import division -from __future__ import absolute_import -from six.moves import range -from six.moves import zip +from __future__ import division, absolute_import +from six.moves import range, zip __copyright__ = "Copyright (C) 2009 Andreas Kloeckner" @@ -250,6 +248,15 @@ class ElementwiseKernel: use_range = range_ is not None or slice_ is not None kernel, arg_descrs = self.get_kernel(use_range) + queue = kwargs.pop("queue", None) + wait_for = kwargs.pop("wait_for", None) + + if wait_for is None: + wait_for = [] + else: + # We'll be modifying it below. + wait_for = list(wait_for) + # {{{ assemble arg array invocation_args = [] @@ -265,13 +272,12 @@ class ElementwiseKernel: invocation_args.append(arg.base_data) if arg_descr.with_offset: invocation_args.append(arg.offset) + wait_for.extend(arg.events) else: invocation_args.append(arg) # }}} - queue = kwargs.pop("queue", None) - wait_for = kwargs.pop("wait_for", None) if kwargs: raise TypeError("unknown keyword arguments: '%s'" % ", ".join(kwargs)) diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index 0b39dd70552b7293dfaa75b903be66c78b791dd0..7c25f05b50e840b249caece15637610bca79f957 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -304,6 +304,12 @@ class ReductionKernel: return_event = kwargs.pop("return_event", False) out = kwargs.pop("out", None) + if wait_for is None: + wait_for = [] + else: + # We'll be modifying it below. + wait_for = list(wait_for) + range_ = kwargs.pop("range", None) slice_ = kwargs.pop("slice", None) @@ -327,6 +333,7 @@ class ReductionKernel: invocation_args.append(arg.base_data) if arg_tp.with_offset: invocation_args.append(arg.offset) + wait_for.extend(arg.events) else: invocation_args.append(arg) @@ -413,6 +420,8 @@ class ReductionKernel: wait_for=wait_for) wait_for = [last_evt] + result.add_event(last_evt) + if group_count == 1: if return_event: return result, last_evt diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 8ec5043d4e1cfb048ce95501127ee7acd0ea760d..6e40c06c883619758b738069290710112b8ed055 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -1468,6 +1468,11 @@ class GenericScanKernel(_GenericScanKernelBase): n = kwargs.get("size") wait_for = kwargs.get("wait_for") + if wait_for is None: + wait_for = [] + else: + wait_for = list(wait_for) + if len(args) != len(self.parsed_args): raise TypeError("expected %d arguments, got %d" % (len(self.parsed_args), len(args))) @@ -1490,6 +1495,7 @@ class GenericScanKernel(_GenericScanKernelBase): data_args.append(arg_val.base_data) if arg_descr.with_offset: data_args.append(arg_val.offset) + wait_for.extend(arg_val.events) else: data_args.append(arg_val) @@ -1678,6 +1684,12 @@ class GenericDebugScanKernel(_GenericScanKernelBase): n = kwargs.get("size") wait_for = kwargs.get("wait_for") + if wait_for is None: + wait_for = [] + else: + # We'll be modifying it below. + wait_for = list(wait_for) + if len(args) != len(self.parsed_args): raise TypeError("expected %d arguments, got %d" % (len(self.parsed_args), len(args))) @@ -1700,6 +1712,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase): data_args.append(arg_val.base_data) if arg_descr.with_offset: data_args.append(arg_val.offset) + wait_for.extend(arg_val.events) else: data_args.append(arg_val)