diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 7562f485cfb2809ad9dc1fc4cdbec45c269c97d8..0db3a2187563e3df31d1eacb659711baebcedb10 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -678,6 +678,9 @@ class GenericScanKernel(object): as *i*. *prev_item* is unavailable when using exclusive scan. *prev_item* in a segmented scan will be the neutral element at a segment boundary, not the immediately preceding item. + + Note that *prev_item enables the construction of an exclusive + scan. :arg is_i_segment_start_expr: If given, makes the scan a segmented scan. Has access to the current index `i` and the input element as `a` and returns a bool. If it returns true, then previous @@ -1050,11 +1053,19 @@ class ExclusiveScanKernel(_ScanKernelBase): # {{{ higher-level trickery @context_dependent_memoize -def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype): +def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types): ctype = dtype_to_ctype(dtype) + arguments = [ + "__global %s *ary" % ctype, + "__global %s *out" % ctype, + "__global unsigned long *count", + ] + [ + "%s %s" % (dtype_to_ctype(arg_dtype), name) + for name, arg_dtype in extra_args_types] + return GenericScanKernel( ctx, dtype, - arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype), + arguments=", ".join(arguments), input_expr="(%s) ? 1 : 0" % predicate, scan_expr="a+b", neutral="0", output_statement=""" @@ -1063,26 +1074,104 @@ def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype): """ ) -def copy_if(ary, predicate, queue=None): +def copy_if(ary, predicate, extra_args=[], queue=None): + """ + :arg extra_args: a list of tuples *(name, value)* specifying extra + arguments to pass to the scan procedure. + """ if len(ary) > np.iinfo(np.uint32): scan_dtype = np.uint64 else: scan_dtype = np.uint32 - knl = get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype) + extra_args_types = tuple((name, val.dtype) for name, val in extra_args) + extra_args_values = tuple(val for name, val in extra_args) + + knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types) out = cl_array.empty_like(ary) count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) - knl(ary, out, count, queue=queue) + knl(ary, out, count, *extra_args_values, queue=queue) return out, count -def remove_if(array, predicate, **kwargs): - pass +def remove_if(ary, predicate, extra_args=[], queue=None): + return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, queue=queue) -def partition(array, predicate): - pass +@context_dependent_memoize +def _get_partition_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types): + ctype = dtype_to_ctype(dtype) + arguments = [ + "__global %s *ary" % ctype, + "__global %s *out_true" % ctype, + "__global %s *out_false" % ctype, + "__global unsigned long *count_true", + ] + [ + "%s %s" % (dtype_to_ctype(arg_dtype), name) + for name, arg_dtype in extra_args_types] + + return GenericScanKernel( + ctx, dtype, + arguments=", ".join(arguments), + input_expr="(%s) ? 1 : 0" % predicate, + scan_expr="a+b", neutral="0", + output_statement=""" + if (prev_item != item) + out_true[item-1] = ary[i]; + else + out_false[i-item] = ary[i]; + if (i+1 == N) *count_true = item; + """ + ) + +def partition(ary, predicate, extra_args=[], queue=None): + """ + :arg extra_args: a list of tuples *(name, value)* specifying extra + arguments to pass to the scan procedure. + """ + if len(ary) > np.iinfo(np.uint32): + scan_dtype = np.uint64 + else: + scan_dtype = np.uint32 + + extra_args_types = tuple((name, val.dtype) for name, val in extra_args) + extra_args_values = tuple(val for name, val in extra_args) + + knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types) + out_true = cl_array.empty_like(ary) + out_false = cl_array.empty_like(ary) + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + knl(ary, out_true, out_false, count, *extra_args_values, queue=queue) + return out_true, out_false, count + +@context_dependent_memoize +def _get_unique_by_key_kernel(ctx, dtype, key_expr, scan_dtype, extra_args_types): + ctype = dtype_to_ctype(dtype) + arguments = [ + "__global %s *ary" % ctype, + "__global %s *out" % ctype, + "__global unsigned long *count_true", + ] + [ + "%s %s" % (dtype_to_ctype(arg_dtype), name) + for name, arg_dtype in extra_args_types] + + return GenericScanKernel( + ctx, dtype, + arguments=", ".join(arguments), + input_expr="(%s) ? 1 : 0" % key_expr, + scan_expr="a+b", neutral="0", + output_statement=""" + if (prev_item != item) + out_true[item-1] = ary[i]; + else + out_false[i-item] = ary[i]; + if (i+1 == N) *count_true = item; + """) + +def unique_by_key(array, key_expr, **kwargs): + """ + :arg extra_args: a list of tuples *(name, value)* specifying extra + arguments to pass to the scan procedure. + """ -def unique_by_key(array, key="", **kwargs): - pass # }}} diff --git a/test/test_array.py b/test/test_array.py index a540971811a4de9c92f97d950801da9202bf7f70..32221b69c5c53d8a9d768614e15c66eb4b549de3 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -773,11 +773,34 @@ def test_copy_if(ctx_factory): from pyopencl.scan import copy_if - selected = a[a>300] - selected_dev, count_dev = copy_if(a_dev, "ary[i] > 300") + crit = a_dev.dtype.type(300) + selected = a[a>crit] + selected_dev, count_dev = copy_if(a_dev, "ary[i] > myval", [("myval", crit)]) assert (selected_dev.get()[:count_dev.get()] == selected).all() +@pytools.test.mark_test.opencl +def test_partition(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.clrandom import rand as clrand + for n in scan_test_counts: + a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000) + a = a_dev.get() + + crit = a_dev.dtype.type(300) + true_host = a[a>crit] + false_host = a[a<=crit] + + from pyopencl.scan import partition + true_dev, false_dev, count_true_dev = partition(a_dev, "ary[i] > myval", [("myval", crit)]) + + count_true_dev = count_true_dev.get() + + assert (true_dev.get()[:count_true_dev] == true_host).all() + assert (false_dev.get()[:n-count_true_dev] == false_host).all() + @pytools.test.mark_test.opencl def test_stride_preservation(ctx_factory): context = ctx_factory()