diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 567d56eb10400559b6af080afb6edf79951d66ee..75e4a83e0dd904859b9d2633b9049dd5356598ef 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -41,7 +41,7 @@ from pyopencl._cluda import CLUDA_PREAMBLE -SHARED_PREAMBLE = CLUDA_PREAMBLE + """ +SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL// #define WG_SIZE ${wg_size} /* SCAN_EXPR has no right know the indices it is scanning at because @@ -49,6 +49,10 @@ each index may occur an undetermined number of times in the scan tree, and thus index-based side computations cannot be meaningful. */ #define SCAN_EXPR(a, b) ${scan_expr} +#define INPUT_EXPR(i) (${input_expr}) +%if is_segmented: + #define IS_SEG_START(i, a) (${is_i_segment_start_expr}) +%endif ${preamble} @@ -66,11 +70,6 @@ typedef ${index_ctype} index_type; SCAN_INTERVALS_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL// #define K ${k_group_size} -%if is_segmented: - #define IS_SEG_START(i, a) (${is_i_segment_start_expr}) -%endif -#define INPUT_EXPR(i) (${input_expr}) - KERNEL REQD_WG_SIZE(WG_SIZE, 1, 1) @@ -437,7 +436,8 @@ void ${name_prefix}_final_update( if(i < interval_end) { - scan_type value = SCAN_EXPR(prev_group_sum, *partial_scan_buffer); + scan_type val = partial_scan_buffer[i]; + scan_type value = SCAN_EXPR(prev_group_sum, val); OUTPUT_STMT(i, value) } } @@ -450,9 +450,7 @@ void ${name_prefix}_final_update( EXCLUSIVE_UPDATE_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL// - borked for now // FIXME - -#define OUTPUT_STMT(i, a) ${output_stmt} +#define OUTPUT_STMT(i, a) ${output_statement} KERNEL REQD_WG_SIZE(WG_SIZE, 1, 1) @@ -462,22 +460,22 @@ void ${name_prefix}_final_update( const index_type interval_size, GLOBAL_MEM scan_type *interval_results, GLOBAL_MEM scan_type *partial_scan_buffer + %if is_segmented: + , GLOBAL_MEM index_type *g_first_segment_start_in_interval + %endif ) { LOCAL_MEM scan_type ldata[WG_SIZE]; const index_type interval_begin = interval_size * GID_0; - const index_type interval_end = min(interval_begin + interval_size, N); + const index_type interval_end = min(interval_begin + interval_size, N); // value to add to this segment scan_type carry = ${neutral}; if(GID_0 != 0) - { - scan_type tmp = interval_results[GID_0 - 1]; - carry = SCAN_EXPR(carry, tmp); - } + carry = interval_results[GID_0 - 1]; - scan_type value = carry; + scan_type value = carry; // (A) for (index_type unit_base = interval_begin; unit_base < interval_end; @@ -485,28 +483,39 @@ void ${name_prefix}_final_update( { const index_type i = unit_base + LID_0; + // load a work group's worth of data if (i < interval_end) { - scan_type tmp = interval_results[i]; + scan_type tmp = partial_scan_buffer[i]; ldata[LID_0] = SCAN_EXPR(carry, tmp); } local_barrier(); + // perform right shift if (LID_0 != 0) value = ldata[LID_0 - 1]; /* - else (see above) - value = carry OR last tail; + else + value = carry (see (A)) OR last tail (see (B)); */ + %if is_segmented: + { + scan_type scan_item_at_i = INPUT_EXPR(i) + bool is_seg_start = IS_SEG_START(i, scan_item_at_i); + if (is_seg_start) + value = ${neutral}; + } + %endif + if (i < interval_end) { OUTPUT_STMT(i, value) } if(LID_0 == 0) - value = ldata[WG_SIZE - 1]; + value = ldata[WG_SIZE - 1]; // (B) local_barrier(); } @@ -608,15 +617,21 @@ class _GenericScanKernelBase(object): i for i, arg in enumerate(self.parsed_args) if isinstance(arg, VectorArg)][0] - if partial_scan_buffer_name is not None: + self.is_segmented = is_i_segment_start_expr is not None + + if self.is_segmented and self.is_exclusive: + # The final update in segmented exclusive scan must be able to + # reconstruct where the segment boundaries were, and therefore + # can't overwrite any of the input. + partial_scan_buffer_name = None + + if partial_scan_buffer_name is not None: self.partial_scan_buffer_idx, = [ i for i, arg in enumerate(self.parsed_args) if arg.name == partial_scan_buffer_name] else: self.partial_scan_buffer_idx = None - self.is_segmented = is_i_segment_start_expr is not None - # {{{ set up shared code dict from pytools import all @@ -713,6 +728,8 @@ class _GenericScanKernelBase(object): wg_size=self.update_wg_size, output_statement=output_statement, argument_signature=arguments, + is_i_segment_start_expr=is_i_segment_start_expr, + input_expr=input_expr, **self.code_variables)) final_update_prg = cl.Program(self.context, final_update_src).build(options) @@ -805,8 +822,8 @@ class _GenericScanKernelBase(object): interval_size, num_intervals = uniform_interval_splitting( n, unit_size, max_intervals) - print "n:%d interval_size: %d num_intervals: %d k_group_size:%d" % ( - n, interval_size, num_intervals, l1_info.k_group_size) + #print "n:%d interval_size: %d num_intervals: %d k_group_size:%d" % ( + #n, interval_size, num_intervals, l1_info.k_group_size) # {{{ first level scan of interval (one interval per block) @@ -865,9 +882,11 @@ class _GenericScanKernelBase(object): class GenericInclusiveScanKernel(_GenericScanKernelBase): final_update_tp = INCLUSIVE_UPDATE_SOURCE + is_exclusive = False class GenericExclusiveScanKernel(_GenericScanKernelBase): final_update_tp = EXCLUSIVE_UPDATE_SOURCE + is_exclusive = True class _ScanKernelBase(_GenericScanKernelBase): def __init__(self, ctx, dtype, @@ -914,8 +933,10 @@ class _ScanKernelBase(_GenericScanKernelBase): class InclusiveScanKernel(_ScanKernelBase): final_update_tp = INCLUSIVE_UPDATE_SOURCE + is_exclusive = False class ExclusiveScanKernel(_ScanKernelBase): final_update_tp = EXCLUSIVE_UPDATE_SOURCE + is_exclusive = True # vim: filetype=pyopencl:fdm=marker diff --git a/test/test_array.py b/test/test_array.py index 8e5a56ac4a56b872189b92239931f79027ef6e73..d3d2f52fc654e9c2d6c13271942fa39adc6cadc5 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -704,8 +704,8 @@ def summarize_error(obtained, desired, orig, thresh=1e-5): entries.append("<%d ok>" % ok_count) ok_count = 0 - entries.append("%r (want: %r, diff: %r, orig: %r)" % (obtained[i], desired[i], - obtained[i]-desired[i], orig[i])) + entries.append("%r (want: %r, got: %r, orig: %r)" % (obtained[i], desired[i], + obtained[i], orig[i])) else: ok_count += 1 @@ -729,13 +729,15 @@ def test_scan(ctx_factory): knl = cls(context, dtype, "a+b", "0") for n in [ - 10, 2 ** 10 - 5, 2 ** 10, - 2 ** 20 - 2 ** 18, - 2 ** 20 - 2 ** 18 + 5, - 2 ** 10 + 5, - 2 ** 20 + 1, - 2 ** 20, 2 ** 24 - ]: + 10, + 2 ** 10 - 5, + 2 ** 10, + 2 ** 10 + 5, + 2 ** 20 - 2 ** 18, + 2 ** 20 - 2 ** 18 + 5, + 2 ** 20 + 1, + 2 ** 20, 2 ** 24 + ]: host_data = np.random.randint(0, 10, n).astype(dtype) dev_data = cl_array.to_device(queue, host_data) @@ -752,6 +754,7 @@ def test_scan(ctx_factory): if 0 and not is_ok: print(summarize_error(dev_data.get(), desired_result, host_data)) + print n, is_ok assert is_ok from gc import collect collect()