diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 9173f5de8fd280083392b2614ce6a9374e40edfb..0aaabb91e33e2d5971bb58d01ec6acd52f0cd022 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -61,6 +61,9 @@ ${preamble} typedef ${scan_ctype} scan_type; typedef ${index_ctype} index_type; + +// NO_SEG_BOUNDARY is the largest representable integer in index_type. +// This assumption is used in code below. #define NO_SEG_BOUNDARY ${index_type_max} """ @@ -108,8 +111,8 @@ typedef ${index_ctype} index_type; # across intervals exchanges data globally, and the final update # adds the exchanged sums to each interval. # -# Exclusive scan is realized by performing a right-shift inside -# the final update. +# Exclusive scan is realized by allowing look-behind (access to the +# preceding item) in the final update, by means of a local shift. SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL// @@ -127,11 +130,13 @@ void ${name_prefix}_scan_intervals( %endif %if is_segmented and is_first_level: /* NO_SEG_BOUNDARY if no segment boundary in interval. - NO_SEG_BOUNDARY is the largest representable integer in index_type. Otherwise, index relative to interval beginning. */ , GLOBAL_MEM index_type *g_first_segment_start_in_interval %endif + %if store_segment_start_flags: + , GLOBAL_MEM index_type *g_segment_start_flags + %endif ) { // padded in WG_SIZE to avoid bank conflicts @@ -139,7 +144,7 @@ void ${name_prefix}_scan_intervals( LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1]; %if is_segmented: - LOCAL_MEM bool l_is_segment_start[K][WG_SIZE]; + LOCAL_MEM bool l_segment_start_flags[K][WG_SIZE]; LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE]; // only relevant/populated for local id 0 @@ -245,7 +250,11 @@ void ${name_prefix}_scan_intervals( ldata[o_mod_k][offset / K] = scan_value; %if is_segmented: - l_is_segment_start[o_mod_k][o_div_k] = IS_SEG_START(read_i, scan_value); + bool is_seg_start = IS_SEG_START(read_i, scan_value); + l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start; + %endif + %if store_segment_start_flags: + g_segment_start_flags[read_i] = is_seg_start; %endif } } @@ -258,13 +267,13 @@ void ${name_prefix}_scan_intervals( local_barrier(); first_segment_start_in_k_group = NO_SEG_BOUNDARY; - if (l_is_segment_start[0][LID_0]) + if (l_segment_start_flags[0][LID_0]) first_segment_start_in_k_group = unit_base + K*LID_0; %endif if (LID_0 == 0 && unit_base != interval_begin %if is_segmented: - && !l_is_segment_start[0][0] + && !l_segment_start_flags[0][0] %endif ) { @@ -293,7 +302,7 @@ void ${name_prefix}_scan_intervals( index_type seq_i = unit_base + K*LID_0 + k; %if is_segmented: - if (l_is_segment_start[k][LID_0]) + if (l_segment_start_flags[k][LID_0]) { first_segment_start_in_k_group = min( first_segment_start_in_k_group, @@ -486,12 +495,9 @@ void ${name_prefix}_scan_intervals( # }}} -# {{{ local update - -# used for inclusive scan, i.e. for output_statements that do not request -# look-behind, i.e. access to the preceding item. +# {{{ update -LOCAL_UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL// +UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL// #define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; } @@ -506,129 +512,119 @@ void ${name_prefix}_final_update( %if is_segmented: , GLOBAL_MEM index_type *g_first_segment_start_in_interval %endif + %if is_segmented and use_lookbehind_update: + , GLOBAL_MEM bool *g_segment_start_flags + %endif ) { + %if use_lookbehind_update: + LOCAL_MEM scan_type ldata[WG_SIZE]; + %endif + %if is_segmented and use_lookbehind_update: + LOCAL_MEM scan_type l_segment_start_flags[WG_SIZE]; + %endif + const index_type interval_begin = interval_size * GID_0; const index_type interval_end = min(interval_begin + interval_size, N); - // value to add to this segment - scan_type prev_group_sum; - if (GID_0 == 0) - prev_group_sum = ${neutral}; - else - prev_group_sum = interval_results[GID_0 - 1]; - - <%def name="update_loop(end)"> - for(; update_i < ${end}; update_i += WG_SIZE) - { - scan_type partial_val = partial_scan_buffer[update_i]; - scan_type value = SCAN_EXPR(prev_group_sum, partial_val); + // carry from last interval + scan_type carry = ${neutral}; + if (GID_0 != 0) + carry = interval_results[GID_0 - 1]; - // printf("i: %d pgs: %d pv: %d val: %d\n", update_i, prev_group_sum, partial_val, value); + %if not use_lookbehind_update: + // {{{ no look-behind ('prev_item' not in output_statement -> simpler) - OUTPUT_STMT(update_i, prev_item_unavailable_with_local_update, value); - } - + index_type update_i = interval_begin+LID_0; - index_type update_i = interval_begin+LID_0; - - %if is_segmented: - // {{{ update to the first intra-interval segment boundary - const index_type first_seg_start_in_interval = - g_first_segment_start_in_interval[GID_0]; + <%def name="update_loop_plain(end, phase)"> + for(; update_i < ${end}; update_i += WG_SIZE) + { + scan_type partial_val = partial_scan_buffer[update_i]; + scan_type value = SCAN_EXPR(carry, partial_val); - index_type seg_end = min(first_seg_start_in_interval, interval_end); - ${update_loop('seg_end')} + OUTPUT_STMT(update_i, prev_item_unavailable_with_local_update, value); + } + - prev_group_sum = ${neutral}; + <% update_loop = self.update_loop_plain %> // }}} - %endif + %else: + // {{{ allow look-behind ('prev_item' in output_statement -> complicated) - ${update_loop('interval_end')} -} -""" + // We are not allowed to branch across barriers at a granularity smaller + // than the whole workgroup. Therefore, the for loop is group-global, + // and there are lots of local ifs. -# }}} + index_type group_base = interval_begin; + scan_type prev_value = carry; // (A) -# {{{ lookbehind update + <%def name="update_loop_lookbehind(end, phase)"> + for(; group_base < ${end}; group_base += WG_SIZE) + { + index_type update_i = group_base+LID_0; -# used for exclusive scan or output_statements that request look-behind, i.e. -# access to the preceding item. + // load a work group's worth of data + if (update_i < ${end}) + { + scan_type tmp = partial_scan_buffer[update_i]; + ldata[LID_0] = SCAN_EXPR(carry, tmp); + %if is_segmented: + l_segment_start_flags[LID_0] = g_segment_start_flags[update_i]; + %endif + } -LOOKBEHIND_UPDATE_SOURCE = SHARED_PREAMBLE + """//CL// + local_barrier(); -#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; } + // find prev_value + if (LID_0 != 0) + prev_value = ldata[LID_0 - 1]; + /* + else + prev_value = carry (see (A)) OR last tail (see (B)); + */ -KERNEL -REQD_WG_SIZE(WG_SIZE, 1, 1) -void ${name_prefix}_final_update( - ${argument_signature}, - const index_type N, - const index_type interval_size, - GLOBAL_MEM scan_type *interval_results, - GLOBAL_MEM scan_type *partial_scan_buffer - %if is_segmented: - , GLOBAL_MEM index_type *g_first_segment_start_in_interval - %endif - ) -{ - LOCAL_MEM scan_type ldata[WG_SIZE]; + if (update_i < ${end}) + { + %if is_segmented: + if (l_segment_start_flags[LID_0]) + prev_value = ${neutral}; + %endif - const index_type interval_begin = interval_size * GID_0; - const index_type interval_end = min(interval_begin + interval_size, N); + scan_type value = ldata[LID_0]; + OUTPUT_STMT(update_i, prev_value, value) + } - // value to add to this segment - scan_type carry = ${neutral}; - if(GID_0 != 0) - carry = interval_results[GID_0 - 1]; + if (LID_0 == 0) + prev_value = ldata[WG_SIZE - 1]; // (B) - scan_type prev_value = carry; // (A) + local_barrier(); + } + - for (index_type unit_base = interval_begin; - unit_base < interval_end; - unit_base += WG_SIZE) - { - const index_type update_i = unit_base + LID_0; + // FIXME TAKE CARE OF BLOCK RESYNCING - // load a work group's worth of data - if (update_i < interval_end) - { - scan_type tmp = partial_scan_buffer[update_i]; - ldata[LID_0] = SCAN_EXPR(carry, tmp); - } + <% update_loop = self.update_loop_lookbehind %> - local_barrier(); + // }}} + %endif - // perform right shift - if (LID_0 != 0) - prev_value = ldata[LID_0 - 1]; - /* - else - prev_value = carry (see (A)) OR last tail (see (B)); - */ + %if is_segmented: + // {{{ update to the first intra-interval segment boundary - %if is_segmented: - { - scan_type scan_item_at_i = INPUT_EXPR(update_i) - bool is_seg_start = IS_SEG_START(update_i, scan_item_at_i); - if (is_seg_start) - prev_value = ${neutral}; - } - %endif + const index_type first_seg_start_in_interval = + g_first_segment_start_in_interval[GID_0]; - if (update_i < interval_end) - { - scan_type value = ldata[LID_0]; + index_type seg_end = min(first_seg_start_in_interval, interval_end); + ${update_loop("seg_end", phase="first_seg")} - OUTPUT_STMT(update_i, prev_value, value) - } + carry = ${neutral}; - if(LID_0 == 0) - prev_value = ldata[WG_SIZE - 1]; // (B) + // }}} + %endif - local_barrier(); - } + ${update_loop("interval_end", phase="remainder")} } """ @@ -675,8 +671,9 @@ _PREFIX_WORDS = set(""" first_segment_start_in_interval k INPUT_EXPR prev_group_sum prev pv value partial_val pgs OUTPUT_STMT is_seg_start update_i scan_item_at_i seq_i read_i - l_ o_mod_k o_div_k l_is_segment_start scan_value sum - first_seg_start_in_interval + l_ o_mod_k o_div_k l_segment_start_flags scan_value sum + first_seg_start_in_interval g_segment_start_flags + group_base seg_end """.split()) _IGNORED_WORDS = set(""" @@ -714,7 +711,9 @@ _IGNORED_WORDS = set(""" gets them stenciled that undefined there up any ones or name only relevant populated even wide we Prepare int seg Note re below place take variable must - intra + intra Therefore find code assumption + branch workgroup complicated granularity phase remainder than simpler + We smaller look ifs lots self behind allow barriers whole loop is_tail is_first_level input_expr argument_signature preamble double_support neutral output_statement index_type_max @@ -722,6 +721,9 @@ _IGNORED_WORDS = set(""" wg_size is_segment_start_expr fetch_expr_offsets arg_ctypes ife_offsets input_fetch_exprs def ife_offset arg_name local_fetch_expr_args update_body + update_loop_lookbehind update_loop_plain update_loop + use_lookbehind_update store_segment_start_flags + update_loop first_seg a b prev_item i prev_item_unavailable_with_local_update prev_value N NO_SEG_BOUNDARY @@ -759,10 +761,11 @@ class _ScanKernelInfo(Record): # }}} class GenericScanKernel(object): + # {{{ constructor + def __init__(self, ctx, dtype, arguments, input_expr, scan_expr, neutral, output_statement, is_segment_start_expr=None, input_fetch_exprs=[], - partial_scan_buffer_name=None, index_dtype=np.int32, name_prefix="scan", options=[], preamble="", devices=None): """ @@ -840,20 +843,7 @@ class GenericScanKernel(object): is_segment_start_expr = is_segment_start_expr.replace("\n", " ") use_lookbehind_update = "prev_item" in output_statement - - if self.is_segmented and use_lookbehind_update: - # The final update in segmented exclusive scan must be able to - # reconstruct where the segment boundaries were, and therefore - # can't overwrite any of the input. - partial_scan_buffer_name = None - - if self.input_fetch_exprs: - # FIXME need to insert code to handle input_fetch_exprs into - # the lookbehind update - raise NotImplementedError("input_fetch_exprs are not supported " - "with a segmented scan using a look-behind update " - "(e.g. an exclusive scan)") - + self.store_segment_start_flags = self.is_segmented and use_lookbehind_update for name, arg_name, ife_offset in input_fetch_exprs: if ife_offset not in [0, -1]: @@ -863,13 +853,6 @@ class GenericScanKernel(object): for arg in self.parsed_args: arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype) - if partial_scan_buffer_name is not None: - self.partial_scan_buffer_idx, = [ - i for i, arg in enumerate(self.parsed_args) - if arg.name == partial_scan_buffer_name] - else: - self.partial_scan_buffer_idx = None - # {{{ set up shared code dict from pytools import all @@ -902,7 +885,8 @@ class GenericScanKernel(object): max_scan_wg_size, arguments, input_expr, is_segment_start_expr, input_fetch_exprs=input_fetch_exprs, - is_first_level=True) + is_first_level=True, + store_segment_start_flags=self.store_segment_start_flags) # Will this device actually let us execute this kernel # at the desired work group size? Building it is the @@ -952,6 +936,7 @@ class GenericScanKernel(object): input_expr="interval_sums[i]", input_fetch_exprs=[], is_first_level=False, + store_segment_start_flags=False, **second_level_build_kwargs) assert min( @@ -966,18 +951,14 @@ class GenericScanKernel(object): self.update_wg_size = min(max_scan_wg_size, 256) - if use_lookbehind_update: - update_src = LOOKBEHIND_UPDATE_SOURCE - else: - update_src = LOCAL_UPDATE_SOURCE - - final_update_tpl = _make_template(update_src) + final_update_tpl = _make_template(UPDATE_SOURCE) final_update_src = str(final_update_tpl.render( wg_size=self.update_wg_size, output_statement=output_statement.replace("\n", " "), argument_signature=arguments.replace("\n", " "), is_segment_start_expr=is_segment_start_expr, input_expr=input_expr.replace("\n", " "), + use_lookbehind_update=use_lookbehind_update, **self.code_variables)) with open("update.cl", "wt") as f: f.write(final_update_src) @@ -990,12 +971,19 @@ class GenericScanKernel(object): + [self.index_dtype, self.index_dtype, None, None]) if self.is_segmented: update_scalar_arg_dtypes.append(None) # g_first_segment_start_in_interval + if self.store_segment_start_flags: + update_scalar_arg_dtypes.append(None) # g_segment_start_flags self.final_update_knl.set_scalar_arg_dtypes(update_scalar_arg_dtypes) # }}} + # }}} + + # {{{ scan kernel build + def build_scan_kernel(self, max_wg_size, arguments, input_expr, - is_segment_start_expr, input_fetch_exprs, is_first_level): + is_segment_start_expr, input_fetch_exprs, is_first_level, + store_segment_start_flags): scalar_arg_dtypes = _get_scalar_arg_dtypes(_parse_args(arguments)) # Thrust says that 128 is big enough for GT200 @@ -1020,6 +1008,7 @@ class GenericScanKernel(object): is_segment_start_expr=is_segment_start_expr, input_fetch_exprs=input_fetch_exprs, is_first_level=is_first_level, + store_segment_start_flags=store_segment_start_flags, **self.code_variables)) with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src) @@ -1036,11 +1025,15 @@ class GenericScanKernel(object): scalar_arg_dtypes.append(None) # interval_results if self.is_segmented and is_first_level: scalar_arg_dtypes.append(None) # g_first_segment_start_in_interval + if store_segment_start_flags: + scalar_arg_dtypes.append(None) # g_segment_start_flags knl.set_scalar_arg_dtypes(scalar_arg_dtypes) return _ScanKernelInfo( kernel=knl, wg_size=wg_size, knl=knl, k_group_size=k_group_size) + # }}} + def __call__(self, *args, **kwargs): # {{{ argument processing @@ -1078,21 +1071,24 @@ class GenericScanKernel(object): interval_size, num_intervals = uniform_interval_splitting( n, unit_size, max_intervals) - #print "n:%d interval_size: %d num_intervals: %d k_group_size:%d" % ( - #n, interval_size, num_intervals, l1_info.k_group_size) - - # {{{ first level scan of interval (one interval per block) + # {{{ allocate some buffers interval_results = cl_array.empty(queue, num_intervals, dtype=self.dtype, allocator=allocator) - if self.partial_scan_buffer_idx is None: - partial_scan_buffer = cl_array.empty( - queue, n, dtype=self.dtype, + partial_scan_buffer = cl_array.empty( + queue, n, dtype=self.dtype, + allocator=allocator) + + if self.store_segment_start_flags: + segment_start_flags = cl_array.empty( + queue, n, dtype=np.bool, allocator=allocator) - else: - partial_scan_buffer = args[self.partial_scan_buffer_idx] + + # }}} + + # {{{ first level scan of interval (one interval per block) scan1_args = data_args + [ partial_scan_buffer.data, n, interval_size, interval_results.data, @@ -1102,7 +1098,10 @@ class GenericScanKernel(object): first_segment_start_in_interval = cl_array.empty(queue, num_intervals, dtype=self.index_dtype, allocator=allocator) - scan1_args = scan1_args + [first_segment_start_in_interval.data] + scan1_args.append(first_segment_start_in_interval.data) + + if self.store_segment_start_flags: + scan1_args.append(segment_start_flags.data) l1_info.kernel( queue, (num_intervals,), (l1_info.wg_size,), @@ -1110,7 +1109,7 @@ class GenericScanKernel(object): # }}} - # {{{ second level inclusive scan of per-interval results + # {{{ second level scan of per-interval results # can scan at most one interval assert interval_size >= num_intervals @@ -1136,6 +1135,8 @@ class GenericScanKernel(object): n, interval_size, interval_results.data, partial_scan_buffer.data] if self.is_segmented: upd_args.append(first_segment_start_in_interval.data) + if self.store_segment_start_flags: + upd_args.append(segment_start_flags.data) self.final_update_knl( queue, (num_intervals,), (self.update_wg_size,), @@ -1160,7 +1161,6 @@ class _ScanKernelBase(GenericScanKernel): scan_expr=scan_expr, neutral=neutral, output_statement=self.ary_output_statement, - partial_scan_buffer_name="output_ary", options=options, preamble=preamble, devices=devices) def __call__(self, input_ary, output_ary=None, allocator=None, queue=None): diff --git a/test/test_array.py b/test/test_array.py index f4a5c391d494e9c7aef8909862447cf0b69d0700..f094ff2379858858f8a62786a8b18fce5e59bec9 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -909,72 +909,85 @@ def test_segmented_scan(ctx_factory): dtype = np.int32 ctype = dtype_to_ctype(dtype) - from pyopencl.scan import GenericScanKernel - knl = GenericScanKernel(context, dtype, - arguments="__global %s *ary, __global char *segflags, __global %s *out" - % (ctype, ctype), - input_expr="ary[i]", - scan_expr="a+b", neutral="0", - is_segment_start_expr="segflags[i]", - output_statement="out[i] = item", - options=["-g", "-O0"]) - - np.set_printoptions(threshold=2000) - from random import randrange - from pyopencl.clrandom import rand as clrand - for n in scan_test_counts: - a_dev = clrand(queue, (n,), dtype=dtype, a=0, b=10) - a = a_dev.get() - - if 10 <= n < 20: - seg_boundaries_values = [ - [0, 9], - [0, 3], - [4, 6], - ] + for is_exclusive in [False, True]: + #for is_exclusive in [True, False]: + if is_exclusive: + output_statement = "out[i] = prev_item" else: - seg_boundaries_values = [] - for i in range(10): - seg_boundary_count = max(2, min(100, randrange(0, int(0.4*n)))) - seg_boundaries = [randrange(n) for i in xrange(seg_boundary_count)] - if n >= 1029: - seg_boundaries.insert(0, 1028) - seg_boundaries.sort() - seg_boundaries_values.append(seg_boundaries) - - for seg_boundaries in seg_boundaries_values: - #print "BOUNDARIES", seg_boundaries - - seg_boundary_flags = np.zeros(n, dtype=np.uint8) - seg_boundary_flags[seg_boundaries] = 1 - seg_boundary_flags_dev = cl_array.to_device(queue, seg_boundary_flags) - - seg_boundaries.insert(0, 0) - - result_host = a.copy() - for i, seg_start in enumerate(seg_boundaries): - if i+1 < len(seg_boundaries): - slc = slice(seg_start, seg_boundaries[i+1]) - else: - slc = slice(seg_start, None) - - result_host[slc] = np.cumsum(result_host[slc]) - - #print "REF", result_host - - result_dev = cl_array.empty_like(a_dev) - knl(a_dev, seg_boundary_flags_dev, result_dev) - - #print "RES", result_dev - is_correct = (result_dev.get() == result_host).all() - if not is_correct: - diff = result_dev.get() - result_host - print "RES-REF", diff - print("ERRWHERE", np.where(diff)) - print(n, list(seg_boundaries_values)) - - assert is_correct - print(n, "done") + output_statement = "out[i] = item" + + from pyopencl.scan import GenericScanKernel + knl = GenericScanKernel(context, dtype, + arguments="__global %s *ary, __global char *segflags, __global %s *out" + % (ctype, ctype), + input_expr="ary[i]", + scan_expr="a+b", neutral="0", + is_segment_start_expr="segflags[i]", + output_statement=output_statement, + options=[]) + + np.set_printoptions(threshold=2000) + from random import randrange + from pyopencl.clrandom import rand as clrand + for n in scan_test_counts: + a_dev = clrand(queue, (n,), dtype=dtype, a=0, b=10) + a = a_dev.get() + + if 10 <= n < 20: + seg_boundaries_values = [ + [0, 9], + [0, 3], + [4, 6], + ] + else: + seg_boundaries_values = [] + for i in range(10): + seg_boundary_count = max(2, min(100, randrange(0, int(0.4*n)))) + seg_boundaries = [randrange(n) for i in xrange(seg_boundary_count)] + if n >= 1029: + seg_boundaries.insert(0, 1028) + seg_boundaries.sort() + seg_boundaries_values.append(seg_boundaries) + + for seg_boundaries in seg_boundaries_values: + #print "BOUNDARIES", seg_boundaries + + seg_boundary_flags = np.zeros(n, dtype=np.uint8) + seg_boundary_flags[seg_boundaries] = 1 + seg_boundary_flags_dev = cl_array.to_device(queue, seg_boundary_flags) + + seg_boundaries.insert(0, 0) + + result_host = a.copy() + for i, seg_start in enumerate(seg_boundaries): + if i+1 < len(seg_boundaries): + seg_end = seg_boundaries[i+1] + else: + seg_end = None + + if is_exclusive: + result_host[seg_start+1:seg_end] = np.cumsum( + result_host[seg_start:seg_end][:-1]) + result_host[seg_start] = 0 + else: + result_host[seg_start:seg_end] = np.cumsum( + result_host[seg_start:seg_end]) + + #print "REF", result_host + + result_dev = cl_array.empty_like(a_dev) + knl(a_dev, seg_boundary_flags_dev, result_dev) + + #print "RES", result_dev + is_correct = (result_dev.get() == result_host).all() + if not is_correct: + diff = result_dev.get() - result_host + print "RES-REF", diff + print("ERRWHERE", np.where(diff)) + print(n, list(seg_boundaries)) + + assert is_correct + print(n, "done")