From d9aeaceba4f4e6872c3413d76e68e9345e8cdf66 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 22 Jul 2012 20:39:38 -0500 Subject: [PATCH] First passing tests on segmented scan of arbitrary sizes. --- pyopencl/scan.py | 274 ++++++++++++++++++++++++++------------------- test/test_array.py | 80 +++++++++++-- 2 files changed, 233 insertions(+), 121 deletions(-) diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 35a23cfa..9173f5de 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -54,7 +54,7 @@ and thus index-based side computations cannot be meaningful. */ #define SCAN_EXPR(a, b) ${scan_expr} #define INPUT_EXPR(i) (${input_expr}) %if is_segmented: - #define IS_SEG_START(i, a) (${is_i_segment_start_expr}) + #define IS_SEG_START(i, a) (${is_segment_start_expr}) %endif ${preamble} @@ -111,7 +111,7 @@ typedef ${index_ctype} index_type; # Exclusive scan is realized by performing a right-shift inside # the final update. -SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + """//CL// +SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL// #define K ${k_group_size} @@ -138,6 +138,16 @@ void ${name_prefix}_scan_intervals( // index K in first dimension used for carry storage LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1]; + %if is_segmented: + LOCAL_MEM bool l_is_segment_start[K][WG_SIZE]; + LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE]; + + // only relevant/populated for local id 0 + index_type first_segment_start_in_interval = NO_SEG_BOUNDARY; + + index_type first_segment_start_in_k_group, first_segment_start_in_subtree; + %endif + // {{{ set up local data for input_fetch_exprs if any of them are stenciled <% @@ -157,13 +167,6 @@ void ${name_prefix}_scan_intervals( // }}} - %if is_segmented: - index_type first_segment_start_in_interval = NO_SEG_BOUNDARY; - LOCAL_MEM index_type l_first_segment_start_in_k_group[WG_SIZE]; - index_type first_segment_start_in_k_group; - %endif - - const index_type interval_begin = interval_size * GID_0; const index_type interval_end = min(interval_begin + interval_size, N); @@ -235,7 +238,15 @@ void ${name_prefix}_scan_intervals( %endif %endfor - ldata[offset % K][offset / K] = INPUT_EXPR(read_i); + scan_type scan_value = INPUT_EXPR(read_i); + + const index_type o_mod_k = offset % K; + const index_type o_div_k = offset / K; + ldata[o_mod_k][offset / K] = scan_value; + + %if is_segmented: + l_is_segment_start[o_mod_k][o_div_k] = IS_SEG_START(read_i, scan_value); + %endif } } @@ -244,23 +255,22 @@ void ${name_prefix}_scan_intervals( // {{{ carry in from previous unit, if applicable %if is_segmented: - if (LID_0 == 0 && unit_base != interval_begin) - { - if (IS_SEG_START(unit_base, ldata[0][0])) - first_segment_start_in_k_group = unit_base; - else - { - ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]); - first_segment_start_in_k_group = NO_SEG_BOUNDARY; - } - } - else - first_segment_start_in_k_group = NO_SEG_BOUNDARY; - %else: - if (LID_0 == 0 && unit_base != interval_begin) - ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]); + local_barrier(); + + first_segment_start_in_k_group = NO_SEG_BOUNDARY; + if (l_is_segment_start[0][LID_0]) + first_segment_start_in_k_group = unit_base + K*LID_0; %endif + if (LID_0 == 0 && unit_base != interval_begin + %if is_segmented: + && !l_is_segment_start[0][0] + %endif + ) + { + ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]); + } + // }}} local_barrier(); @@ -283,9 +293,11 @@ void ${name_prefix}_scan_intervals( index_type seq_i = unit_base + K*LID_0 + k; %if is_segmented: - if (IS_SEG_START(seq_i, tmp) + if (l_is_segment_start[k][LID_0]) { - first_segment_start_in_k_group = seq_i; + first_segment_start_in_k_group = min( + first_segment_start_in_k_group, + seq_i); sum = tmp; } else @@ -302,7 +314,7 @@ void ${name_prefix}_scan_intervals( ldata[K][LID_0] = sum; %if is_segmented: - l_first_segment_start_in_k_group[LID_0] = first_segment_start_in_k_group; + l_first_segment_start_in_subtree[LID_0] = first_segment_start_in_k_group; %endif local_barrier(); @@ -325,38 +337,40 @@ void ${name_prefix}_scan_intervals( <% scan_offset = 1 %> - %if is_segmented: - index_type first_segment_start_in_subtree; - %endif - % while scan_offset <= wg_size: // {{{ reads from local allowed, writes to local not allowed - if ( - LID_0 >= ${scan_offset} - % if is_tail: - && K*LID_0 < offset_end - % endif - ) + if (LID_0 >= ${scan_offset}) { scan_type tmp = ldata[K][LID_0 - ${scan_offset}]; - %if is_segmented: - if (l_first_segment_start_in_k_group[LID_0] == NO_SEG_BOUNDARY) + % if is_tail: + if (K*LID_0 < offset_end) + % endif + { + %if is_segmented: + if (l_first_segment_start_in_subtree[LID_0] == NO_SEG_BOUNDARY) + val = SCAN_EXPR(tmp, val); + %else: val = SCAN_EXPR(tmp, val); + %endif + } + + %if is_segmented: + // Prepare for l_first_segment_start_in_subtree, below. - // update l_first_segment_start_in_k_group regardless - segment_start_in_subtree = min( - l_first_segment_start_in_k_group[LID_0], - l_first_segment_start_in_k_group[LID_0 - ${scan_offset}]); - %else: - val = SCAN_EXPR(tmp, val); + // Note that this update must take place *even* if we're + // out of bounds. + + first_segment_start_in_subtree = min( + l_first_segment_start_in_subtree[LID_0], + l_first_segment_start_in_subtree[LID_0 - ${scan_offset}]); %endif } %if is_segmented: else { first_segment_start_in_subtree = - l_first_segment_start_in_k_group[LID_0]; + l_first_segment_start_in_subtree[LID_0]; } %endif @@ -368,13 +382,29 @@ void ${name_prefix}_scan_intervals( ldata[K][LID_0] = val; %if is_segmented: - segment_start_in_k_group[LID_0] = segment_start_in_subtree; + l_first_segment_start_in_subtree[LID_0] = + first_segment_start_in_subtree; %endif // }}} local_barrier(); + %if 0: + if (LID_0 == 0) + { + printf("${scan_offset}: "); + for (int i = 0; i < WG_SIZE; ++i) + { + if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY) + printf("- "); + else + printf("%d ", l_first_segment_start_in_subtree[i]); + } + printf("\n"); + } + %endif + <% scan_offset *= 2 %> % endwhile @@ -408,11 +438,10 @@ void ${name_prefix}_scan_intervals( %if is_segmented: if (LID_0 == 0) { - // carry in from previous unit - first_segment_start_in_interval = - first_segment_start_in_interval - || - segment_start_in_k_group[WG_SIZE-1]; + // update interval-wide first-seg variable from current unit + first_segment_start_in_interval = min( + first_segment_start_in_interval, + l_first_segment_start_in_subtree[WG_SIZE-1]); } %endif @@ -459,7 +488,8 @@ void ${name_prefix}_scan_intervals( # {{{ local update -# used for inclusive scan +# used for inclusive scan, i.e. for output_statements that do not request +# look-behind, i.e. access to the preceding item. LOCAL_UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL// @@ -478,18 +508,9 @@ void ${name_prefix}_final_update( %endif ) { - %if neutral is None: - if (GID_0 == 0) - return; - %endif - const index_type interval_begin = interval_size * GID_0; const index_type interval_end = min(interval_begin + interval_size, N); - %if is_segmented: - interval_end = min(interval_end, g_first_segment_start_in_interval[GID_0]); - %endif - // value to add to this segment scan_type prev_group_sum; if (GID_0 == 0) @@ -497,13 +518,8 @@ void ${name_prefix}_final_update( else prev_group_sum = interval_results[GID_0 - 1]; - for(index_type unit_base = interval_begin; - unit_base < interval_end; - unit_base += WG_SIZE) - { - const index_type update_i = unit_base + LID_0; - - if(update_i < interval_end) + <%def name="update_loop(end)"> + for(; update_i < ${end}; update_i += WG_SIZE) { scan_type partial_val = partial_scan_buffer[update_i]; scan_type value = SCAN_EXPR(prev_group_sum, partial_val); @@ -512,7 +528,24 @@ void ${name_prefix}_final_update( OUTPUT_STMT(update_i, prev_item_unavailable_with_local_update, value); } - } + </%def> + + index_type update_i = interval_begin+LID_0; + + %if is_segmented: + // {{{ update to the first intra-interval segment boundary + const index_type first_seg_start_in_interval = + g_first_segment_start_in_interval[GID_0]; + + index_type seg_end = min(first_seg_start_in_interval, interval_end); + ${update_loop('seg_end')} + + prev_group_sum = ${neutral}; + + // }}} + %endif + + ${update_loop('interval_end')} } """ @@ -633,21 +666,22 @@ _PREFIX_WORDS = set(""" ldata partial_scan_buffer global scan_offset segment_start_in_k_group carry g_first_segment_start_in_interval IS_SEG_START tmp Z - val l_first_segment_start_in_k_group unit_size + val l_first_segment_start_in_subtree unit_size index_type interval_begin interval_size offset_end K - SCAN_EXPR do_update NO_SEG_BOUNDARY WG_SIZE + SCAN_EXPR do_update WG_SIZE first_segment_start_in_k_group scan_type segment_start_in_subtree offset interval_results interval_end first_segment_start_in_subtree unit_base first_segment_start_in_interval k INPUT_EXPR - prev_group_sum prev pv value n partial_val pgs OUTPUT_STMT + prev_group_sum prev pv value partial_val pgs OUTPUT_STMT is_seg_start update_i scan_item_at_i seq_i read_i - l_ + l_ o_mod_k o_div_k l_is_segment_start scan_value sum + first_seg_start_in_interval """.split()) _IGNORED_WORDS = set(""" typedef for endfor if void while endwhile endfor endif else const printf - None return bool + None return bool n set iteritems len setdefault @@ -666,7 +700,7 @@ _IGNORED_WORDS = set(""" positions all padded integer its previous write based true writes 0 has local worth scan_expr to read cannot not X items False bank - four beginning follows applicable sum item min each indices works side + four beginning follows applicable item min each indices works side scanning right summed relative used id out index avoid current state boundary True across be This reads groups along Otherwise undetermined store of times prior s update first regardless Each number because @@ -678,17 +712,19 @@ _IGNORED_WORDS = set(""" A load B group perform shift tail see last OR this add fetched into are directly need gets them stenciled that undefined - there up any ones or name + there up any ones or name only relevant populated + even wide we Prepare int seg Note re below place take variable must + intra is_tail is_first_level input_expr argument_signature preamble double_support neutral output_statement index_type_max k_group_size name_prefix is_segmented index_ctype scan_ctype - wg_size is_i_segment_start_expr fetch_expr_offsets - arg_ctypes ife_offsets input_fetch_exprs - ife_offset arg_name local_fetch_expr_args + wg_size is_segment_start_expr fetch_expr_offsets + arg_ctypes ife_offsets input_fetch_exprs def + ife_offset arg_name local_fetch_expr_args update_body a b prev_item i prev_item_unavailable_with_local_update prev_value - N + N NO_SEG_BOUNDARY """.split()) def _make_template(s): @@ -725,7 +761,7 @@ class _ScanKernelInfo(Record): class GenericScanKernel(object): def __init__(self, ctx, dtype, arguments, input_expr, scan_expr, neutral, output_statement, - is_i_segment_start_expr=None, input_fetch_exprs=[], + is_segment_start_expr=None, input_fetch_exprs=[], partial_scan_buffer_name=None, index_dtype=np.int32, name_prefix="scan", options=[], preamble="", devices=None): @@ -751,7 +787,7 @@ class GenericScanKernel(object): Note that *prev_item enables the construction of an exclusive scan. - :arg is_i_segment_start_expr: If given, makes the scan a segmented + :arg is_segment_start_expr: If given, makes the scan a segmented scan. Has access to the current index `i` and the input element as `a` and returns a bool. If it returns true, then previous sums will not spill over into the item with index i. @@ -799,9 +835,9 @@ class GenericScanKernel(object): i for i, arg in enumerate(self.parsed_args) if isinstance(arg, VectorArg)][0] - self.is_segmented = is_i_segment_start_expr is not None + self.is_segmented = is_segment_start_expr is not None if self.is_segmented: - is_i_segment_start_expr = is_i_segment_start_expr.replace("\n", " ") + is_segment_start_expr = is_segment_start_expr.replace("\n", " ") use_lookbehind_update = "prev_item" in output_statement @@ -864,7 +900,7 @@ class GenericScanKernel(object): while True: candidate_scan_info = self.build_scan_kernel( max_scan_wg_size, arguments, input_expr, - is_i_segment_start_expr, + is_segment_start_expr, input_fetch_exprs=input_fetch_exprs, is_first_level=True) @@ -901,14 +937,14 @@ class GenericScanKernel(object): "__global %s *g_first_segment_start_in_interval_input" % dtype_to_ctype(self.index_dtype)) - # is_i_segment_start_expr answers the question "should previous sums + # is_segment_start_expr answers the question "should previous sums # spill over into this item". And since g_first_segment_start_in_interval_input # answers the question if a segment boundary was found in an interval of data, # then if not, it's ok to spill over. - second_level_build_kwargs["is_i_segment_start_expr"] = \ + second_level_build_kwargs["is_segment_start_expr"] = \ "g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY" else: - second_level_build_kwargs["is_i_segment_start_expr"] = None + second_level_build_kwargs["is_segment_start_expr"] = None self.second_level_scan_info = self.build_scan_kernel( max_scan_wg_size, @@ -940,22 +976,26 @@ class GenericScanKernel(object): wg_size=self.update_wg_size, output_statement=output_statement.replace("\n", " "), argument_signature=arguments.replace("\n", " "), - is_i_segment_start_expr=is_i_segment_start_expr, + is_segment_start_expr=is_segment_start_expr, input_expr=input_expr.replace("\n", " "), **self.code_variables)) + with open("update.cl", "wt") as f: f.write(final_update_src) final_update_prg = cl.Program(self.context, final_update_src).build(options) self.final_update_knl = getattr( final_update_prg, name_prefix+"_final_update") - self.final_update_knl.set_scalar_arg_dtypes( + update_scalar_arg_dtypes = ( _get_scalar_arg_dtypes(self.parsed_args) + [self.index_dtype, self.index_dtype, None, None]) + if self.is_segmented: + update_scalar_arg_dtypes.append(None) # g_first_segment_start_in_interval + self.final_update_knl.set_scalar_arg_dtypes(update_scalar_arg_dtypes) # }}} def build_scan_kernel(self, max_wg_size, arguments, input_expr, - is_i_segment_start_expr, input_fetch_exprs, is_first_level): + is_segment_start_expr, input_fetch_exprs, is_first_level): scalar_arg_dtypes = _get_scalar_arg_dtypes(_parse_args(arguments)) # Thrust says that 128 is big enough for GT200 @@ -972,17 +1012,19 @@ class GenericScanKernel(object): k_group_size = 8 scan_tpl = _make_template(SCAN_INTERVALS_SOURCE) - scan_intervals_src = str(scan_tpl.render( + scan_src = str(scan_tpl.render( wg_size=wg_size, input_expr=input_expr, k_group_size=k_group_size, argument_signature=arguments.replace("\n", " "), - is_i_segment_start_expr=is_i_segment_start_expr, + is_segment_start_expr=is_segment_start_expr, input_fetch_exprs=input_fetch_exprs, is_first_level=is_first_level, **self.code_variables)) - prg = cl.Program(self.context, scan_intervals_src).build(self.options) + with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src) + + prg = cl.Program(self.context, scan_src).build(self.options) knl = getattr( prg, @@ -1041,30 +1083,31 @@ class GenericScanKernel(object): # {{{ first level scan of interval (one interval per block) - interval_results = allocator(self.dtype.itemsize*num_intervals) + interval_results = cl_array.empty(queue, + num_intervals, dtype=self.dtype, + allocator=allocator) if self.partial_scan_buffer_idx is None: - partial_scan_buffer = allocator(n*self.dtype.itemsize) + partial_scan_buffer = cl_array.empty( + queue, n, dtype=self.dtype, + allocator=allocator) else: - partial_scan_buffer = data_args[self.partial_scan_buffer_idx] + partial_scan_buffer = args[self.partial_scan_buffer_idx] scan1_args = data_args + [ - partial_scan_buffer, n, interval_size, interval_results, + partial_scan_buffer.data, n, interval_size, interval_results.data, ] - if self.code_variables["is_segmented"]: - first_segment_start_in_interval = allocator(self.index_dtype.itemsize*num_intervals) - scan1_args = scan1_args + (first_segment_start_in_interval,) + if self.is_segmented: + first_segment_start_in_interval = cl_array.empty(queue, + num_intervals, dtype=self.index_dtype, + allocator=allocator) + scan1_args = scan1_args + [first_segment_start_in_interval.data] l1_info.kernel( queue, (num_intervals,), (l1_info.wg_size,), *scan1_args, **dict(g_times_l=True)) - if 0: - psb_host = np.empty(n, self.dtype) - cl.enqueue_copy(queue, psb_host, partial_scan_buffer) - print "PSB", psb_host - # }}} # {{{ second level inclusive scan of per-interval results @@ -1072,10 +1115,14 @@ class GenericScanKernel(object): # can scan at most one interval assert interval_size >= num_intervals - scan2_args = data_args + [interval_results, interval_results] + scan2_args = data_args + [ + interval_results.data, # interval_sums + ] if self.is_segmented: - scan2_args = scan2_args + [first_segment_start_in_interval] - scan2_args = scan2_args + [num_intervals, interval_size] + scan2_args.append(first_segment_start_in_interval.data) + scan2_args = scan2_args + [ + interval_results.data, # partial_scan_buffer + num_intervals, interval_size] l2_info.kernel( queue, (1,), (l1_info.wg_size,), @@ -1085,9 +1132,10 @@ class GenericScanKernel(object): # {{{ update intervals with result of interval scan - upd_args = data_args + [n, interval_size, interval_results, partial_scan_buffer] + upd_args = data_args + [ + n, interval_size, interval_results.data, partial_scan_buffer.data] if self.is_segmented: - upd_args = upd_args.append(first_segment_start_in_interval) + upd_args.append(first_segment_start_in_interval.data) self.final_update_knl( queue, (num_intervals,), (self.update_wg_size,), diff --git a/test/test_array.py b/test/test_array.py index b4d2aa63..7e88fdfe 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -761,9 +761,15 @@ def summarize_error(obtained, desired, orig, thresh=1e-5): scan_test_counts = [ 10, + 2 ** 8 - 1, + 2 ** 8, + 2 ** 8 + 1, 2 ** 10 - 5, 2 ** 10, 2 ** 10 + 5, + 2 ** 12 - 5, + 2 ** 12, + 2 ** 12 + 5, 2 ** 20 - 2 ** 18, 2 ** 20 - 2 ** 18 + 5, 2 ** 20 + 1, @@ -872,20 +878,78 @@ def test_segmented_scan(ctx_factory): context = ctx_factory() queue = cl.CommandQueue(context) + from pyopencl.tools import dtype_to_ctype + dtype = np.int32 + ctype = dtype_to_ctype(dtype) + + from pyopencl.scan import GenericScanKernel + knl = GenericScanKernel(context, dtype, + arguments="__global %s *ary, __global char *segflags, __global %s *out" + % (ctype, ctype), + input_expr="ary[i]", + scan_expr="a+b", neutral="0", + is_segment_start_expr="segflags[i]", + output_statement="out[i] = item", + options=["-g", "-O0"]) + + np.set_printoptions(threshold=2000) from random import randrange from pyopencl.clrandom import rand as clrand for n in scan_test_counts: - a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000) + a_dev = clrand(queue, (n,), dtype=dtype, a=0, b=10) a = a_dev.get() - seg_boundary_count = min(100, randrange(0, int(0.4*n))) - seg_boundaries = np.fromiter(sorted(randrange(n) for i in xrange(seg_boundary_count)), - dtype=np.intp) - print seg_boundaries + if 10 <= n < 20: + seg_boundaries_values = [ + [0, 9], + [0, 3], + [4, 6], + ] + else: + seg_boundaries_values = [] + for i in range(10): + seg_boundary_count = max(2, min(100, randrange(0, int(0.4*n)))) + seg_boundaries = [randrange(n) for i in xrange(seg_boundary_count)] + if n >= 1029: + seg_boundaries.insert(0, 1028) + seg_boundaries.sort() + seg_boundaries_values.append(seg_boundaries) + + for seg_boundaries in seg_boundaries_values: + #print "BOUNDARIES", seg_boundaries + + seg_boundary_flags = np.zeros(n, dtype=np.uint8) + seg_boundary_flags[seg_boundaries] = 1 + seg_boundary_flags_dev = cl_array.to_device(queue, seg_boundary_flags) + + seg_boundaries.insert(0, 0) + + result_host = a.copy() + for i, seg_start in enumerate(seg_boundaries): + if i+1 < len(seg_boundaries): + slc = slice(seg_start, seg_boundaries[i+1]) + else: + slc = slice(seg_start, None) + + result_host[slc] = np.cumsum(result_host[slc]) + + #print "REF", result_host + + result_dev = cl_array.empty_like(a_dev) + knl(a_dev, seg_boundary_flags_dev, result_dev) + + #print "RES", result_dev + is_correct = (result_dev.get() == result_host).all() + if not is_correct: + diff = result_dev.get() - result_host + print "RES-REF", diff + print("ERRWHERE", np.where(diff)) + print(n, list(seg_boundaries_values)) + + assert is_correct + print(n, "done") + - seg_boundary_flags = np.zeros(n, dtype=np.uint8) - seg_boundary_flags[seg_boundaries] = 1 - seg_boundary_flags_dev = cl_array.to_device(queue, seg_boundary_flags) # }}} -- GitLab