diff --git a/doc/algorithm.rst b/doc/algorithm.rst index 25396c4c55163c8c85eabe3d9de2fbf540fcdd33..6eccc2f86683d39bd0bf719b414e3befdef987da 100644 --- a/doc/algorithm.rst +++ b/doc/algorithm.rst @@ -119,8 +119,9 @@ Prefix Sums ("scan") .. module:: pyopencl.scan .. |scan_extra_args| replace:: a list of tuples *(name, value)* specifying - extra arguments to pass to the scan procedure. *value* must be :mod:`numpy` - sized type. + extra arguments to pass to the scan procedure. For version 2013.1, + *value* must be a of a :mod:`numpy` sized scalar type. As of version 2013.2, + *value* may also be a :class:`pyopencl.array.Array`. .. |preamble| replace:: A snippet of C that is inserted into the compiled kernel before the actual kernel function. May be used for, e.g. type definitions or include statements. @@ -231,7 +232,7 @@ Simple / Legacy Interface an associative binary operation. *neutral* is the neutral element of *scan_expr*, obeying *scan_expr(a, neutral) == a*. - *dtype* specifies the type of the arrays being operated on. + *dtype* specifies the type of the arrays being operated on. *name_prefix* is used for kernel names to ensure recognizability in profiles and logs. *options* is a list of compiler options to use when building. *preamble* specifies a string of code that is diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 2ad82f7b3ca70bf85169474b5fef64d2cf36be65..3dfdab556ca13fc9cc13a26f441617a34f39edc3 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -49,6 +49,24 @@ _copy_if_template = ScanTemplate( template_processor="printf") +def extract_extra_args_types_values(extra_args): + from pyopencl.tools import VectorArg, ScalarArg + + extra_args_types = [] + extra_args_values = [] + for name, val in extra_args: + if isinstance(val, cl.array.Array): + extra_args_types.append(VectorArg(val.dtype, name, with_offset=False)) + extra_args_values.append(val) + elif isinstance(val, np.generic): + extra_args_types.append(ScalarArg(val.dtype, name)) + extra_args_values.append(val) + else: + raise RuntimeError("argument '%d' not understood" % name) + + return tuple(extra_args_types), extra_args_values + + def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None): """Copy the elements of *ary* satisfying *predicate* to an output array. @@ -70,8 +88,7 @@ def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=Non else: scan_dtype = np.int32 - extra_args_types = tuple((val.dtype, name) for name, val in extra_args) - extra_args_values = tuple(val for name, val in extra_args) + extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) knl = _copy_if_template.build(ary.context, type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)), @@ -153,8 +170,7 @@ def partition(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=N else: scan_dtype = np.uint32 - extra_args_types = tuple((val.dtype, name) for name, val in extra_args) - extra_args_values = tuple(val for name, val in extra_args) + extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) knl = _partition_template.build( ary.context, @@ -222,8 +238,7 @@ def unique(ary, is_equal_expr="a == b", extra_args=[], preamble="", else: scan_dtype = np.uint32 - extra_args_types = tuple((val.dtype, name) for name, val in extra_args) - extra_args_values = tuple(val for name, val in extra_args) + extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) knl = _unique_template.build( ary.context,