diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 527ab76a78dbd8de6c433146781c03150838bf5e..6c22572128717394beb90a2c19752b52b58acf72 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -577,7 +577,7 @@ _PREFIX_WORDS = set(""" index_type interval_begin interval_size offset_end K SCAN_EXPR do_update NO_SEG_BOUNDARY WG_SIZE first_segment_start_in_k_group scan_type - segment_start_in_subtree offset interval_results interval_end N + segment_start_in_subtree offset interval_results interval_end first_segment_start_in_subtree unit_base first_segment_start_in_interval k INPUT_EXPR prev_group_sum prev pv this add value n partial_val pgs OUTPUT_STMT @@ -619,6 +619,7 @@ _IGNORED_WORDS = set(""" wg_size is_i_segment_start_expr a b prev_item i prev_item_unavailable_with_local_update prev_value + N """.split()) def _make_template(s): @@ -684,6 +685,9 @@ class GenericScanKernel(object): The first array in the argument list determines the size of the index space over which the scan is carried out. + + All code fragments further have access to N, the number of elements + being processed in the scan. """ if isinstance(self, ExclusiveScanKernel) and neutral is None: @@ -714,6 +718,9 @@ class GenericScanKernel(object): if isinstance(arg, VectorArg)][0] self.is_segmented = is_i_segment_start_expr is not None + if self.is_segmented: + is_i_segment_start_expr = is_i_segment_start_expr.replace("\n", " ") + use_lookbehind_update = "prev_item" in output_statement if self.is_segmented and use_lookbehind_update: @@ -741,8 +748,8 @@ class GenericScanKernel(object): index_type_max=str(np.iinfo(self.index_dtype).max), scan_ctype=dtype_to_ctype(dtype), is_segmented=self.is_segmented, - scan_expr=scan_expr, - neutral=neutral, + scan_expr=scan_expr.replace("\n", " "), + neutral=neutral.replace("\n", " "), double_support=all( has_double_support(dev) for dev in devices), ) @@ -829,10 +836,10 @@ class GenericScanKernel(object): final_update_tpl = _make_template(update_src) final_update_src = str(final_update_tpl.render( wg_size=self.update_wg_size, - output_statement=output_statement, - argument_signature=arguments, + output_statement=output_statement.replace("\n", " "), + argument_signature=arguments.replace("\n", " "), is_i_segment_start_expr=is_i_segment_start_expr, - input_expr=input_expr, + input_expr=input_expr.replace("\n", " "), **self.code_variables)) final_update_prg = cl.Program(self.context, final_update_src).build(options) @@ -867,7 +874,7 @@ class GenericScanKernel(object): wg_size=wg_size, input_expr=input_expr, k_group_size=k_group_size, - argument_signature=arguments, + argument_signature=arguments.replace("\n", " "), is_i_segment_start_expr=is_i_segment_start_expr, is_first_level=is_first_level, **self.code_variables)) @@ -1047,18 +1054,24 @@ def get_copy_if_kernel(ctx, dtype, predicate): ctype = dtype_to_ctype(dtype) return GenericScanKernel( ctx, np.uint32, - arguments="__global %s *ary, __global %s *out" % (ctype, ctype), + arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype), input_expr="(%s) ? 1 : 0" % predicate, scan_expr="a+b", neutral="0", - output_statement="if (prev_item != item) out[item-1] = ary[i];" + output_statement=""" + if (prev_item != item) out[item-1] = ary[i]; + if (i+1 == N) *count = item; + """ ) def copy_if(ary, predicate, queue=None): + # FIXME use 64-bit scan, eventually + # (not relevant for 6GB GPUs) + knl = get_copy_if_kernel(ary.context, ary.dtype, predicate) - # FIXME return count out = cl_array.empty_like(ary) - knl(ary, out, queue=queue) - return out + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + knl(ary, out, count, queue=queue) + return out, count def remove_if(array, predicate, **kwargs): pass diff --git a/test/test_array.py b/test/test_array.py index 984334ba0ee9e57288077f12159806ae8bad2ce2..a540971811a4de9c92f97d950801da9202bf7f70 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -714,6 +714,17 @@ def summarize_error(obtained, desired, orig, thresh=1e-5): return " ".join(entries) +scan_test_counts = [ + 10, + 2 ** 10 - 5, + 2 ** 10, + 2 ** 10 + 5, + 2 ** 20 - 2 ** 18, + 2 ** 20 - 2 ** 18 + 5, + 2 ** 20 + 1, + 2 ** 20, 2 ** 24 + ] + @pytools.test.mark_test.opencl def test_scan(ctx_factory): context = ctx_factory() @@ -728,16 +739,7 @@ def test_scan(ctx_factory): ]: knl = cls(context, dtype, "a+b", "0") - for n in [ - 10, - 2 ** 10 - 5, - 2 ** 10, - 2 ** 10 + 5, - 2 ** 20 - 2 ** 18, - 2 ** 20 - 2 ** 18 + 5, - 2 ** 20 + 1, - 2 ** 20, 2 ** 24 - ]: + for n in scan_test_counts: host_data = np.random.randint(0, 10, n).astype(dtype) dev_data = cl_array.to_device(queue, host_data) @@ -765,15 +767,16 @@ def test_copy_if(ctx_factory): queue = cl.CommandQueue(context) from pyopencl.clrandom import rand as clrand - a_dev = clrand(queue, (200000,), dtype=np.int32, a=0, b=1000) - a = a_dev.get() + for n in scan_test_counts: + a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000) + a = a_dev.get() - from pyopencl.scan import copy_if + from pyopencl.scan import copy_if - selected = a[a>300] - selected_dev = copy_if(a_dev, "ary[i] > 300").get()[:len(selected)] + selected = a[a>300] + selected_dev, count_dev = copy_if(a_dev, "ary[i] > 300") - assert (selected_dev == selected).all() + assert (selected_dev.get()[:count_dev.get()] == selected).all() @pytools.test.mark_test.opencl def test_stride_preservation(ctx_factory):