From 58356197ff0f2ef9d288357b60c8a7387512fed0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 29 Jul 2012 20:40:17 -0400 Subject: [PATCH] Shift responsibility for segmentation to scan_expr. --- pyopencl/scan.py | 139 ++++++++++++++++++++++++++------------------- test/test_array.py | 5 +- 2 files changed, 84 insertions(+), 60 deletions(-) diff --git a/pyopencl/scan.py b/pyopencl/scan.py index d4744e5c..6565682f 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -47,11 +47,7 @@ from pyopencl.tools import context_dependent_memoize SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL// #define WG_SIZE ${wg_size} -/* SCAN_EXPR has no right know the indices it is scanning at because -each index may occur an undetermined number of times in the scan tree, -and thus index-based side computations cannot be meaningful. */ - -#define SCAN_EXPR(a, b) ${scan_expr} +#define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr} #define INPUT_EXPR(i) (${input_expr}) %if is_segmented: #define IS_SEG_START(i, a) (${is_segment_start_expr}) @@ -272,13 +268,15 @@ void ${name_prefix}_scan_intervals( first_segment_start_in_k_group = unit_base + K*LID_0; %endif - if (LID_0 == 0 && unit_base != interval_begin - %if is_segmented: - && !l_segment_start_flags[0][0] - %endif - ) + if (LID_0 == 0 && unit_base != interval_begin) { - ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]); + ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0], + %if is_segmented: + (l_segment_start_flags[0][0]) + %else: + false + %endif + ); } // }}} @@ -308,11 +306,16 @@ void ${name_prefix}_scan_intervals( first_segment_start_in_k_group = min( first_segment_start_in_k_group, seq_i); - sum = tmp; } - else %endif - sum = SCAN_EXPR(sum, tmp); + + sum = SCAN_EXPR(sum, tmp, + %if is_segmented: + (l_segment_start_flags[k][LID_0]) + %else: + false + %endif + ); ldata[k][LID_0] = sum; } @@ -357,12 +360,13 @@ void ${name_prefix}_scan_intervals( if (K*LID_0 < offset_end) % endif { - %if is_segmented: - if (l_first_segment_start_in_subtree[LID_0] == NO_SEG_BOUNDARY) - val = SCAN_EXPR(tmp, val); - %else: - val = SCAN_EXPR(tmp, val); - %endif + val = SCAN_EXPR(tmp, val, + %if is_segmented: + (l_first_segment_start_in_subtree[LID_0] != NO_SEG_BOUNDARY) + %else: + false + %endif + ); } %if is_segmented: @@ -428,19 +432,19 @@ void ${name_prefix}_scan_intervals( for(index_type k = 0; k < K; k++) { - bool do_update = true; %if is_tail: - do_update = K * LID_0 + k < offset_end; + if (K * LID_0 + k < offset_end) %endif - %if is_segmented: - do_update = unit_base + K * LID_0 + k - < first_segment_start_in_k_group; - %endif - - if (do_update) { scan_type tmp = ldata[k][LID_0]; - ldata[k][LID_0] = SCAN_EXPR(sum, tmp); + ldata[k][LID_0] = SCAN_EXPR(sum, tmp, + %if is_segmented: + (unit_base + K * LID_0 + k + >= first_segment_start_in_k_group) + %else: + false + %endif + ); } } } @@ -542,28 +546,24 @@ void ${name_prefix}_final_update( index_type update_i = interval_begin+LID_0; - <%def name="update_loop_plain(end)"> - for(; update_i < ${end}; update_i += WG_SIZE) - { - scan_type partial_val = partial_scan_buffer[update_i]; - scan_type item = SCAN_EXPR(carry, partial_val); - index_type i = update_i; - - { ${output_statement}; } - } - </%def> - - // {{{ update to the first intra-interval segment boundary - %if is_segmented: index_type seg_end = min(first_seg_start_in_interval, interval_end); - ${update_loop_plain("seg_end")} - carry = ${neutral}; %endif - // }}} + for(; update_i < interval_end; update_i += WG_SIZE) + { + scan_type partial_val = partial_scan_buffer[update_i]; + scan_type item = SCAN_EXPR(carry, partial_val, + %if is_segmented: + (update_i >= seg_end) + %else: + false + %endif + ); + index_type i = update_i; - ${update_loop_plain("interval_end")} + { ${output_statement}; } + } // }}} %else: @@ -585,10 +585,13 @@ void ${name_prefix}_final_update( { scan_type tmp = partial_scan_buffer[update_i]; - %if is_segmented: - if (update_i < first_seg_start_in_interval) - %endif - { tmp = SCAN_EXPR(carry, tmp); } + tmp = SCAN_EXPR(carry, tmp, + %if is_segmented: + (update_i >= first_seg_start_in_interval) + %else: + false + %endif + ); ldata[LID_0] = tmp; @@ -703,7 +706,7 @@ _PREFIX_WORDS = set(""" _IGNORED_WORDS = set(""" typedef for endfor if void while endwhile endfor endif else const printf - None return bool n char + None return bool n char true false set iteritems len setdefault @@ -716,7 +719,7 @@ _IGNORED_WORDS = set(""" _final_update _scan_intervals _debug_scan - positions all padded integer its previous write based true writes 0 + positions all padded integer its previous write based writes 0 has local worth scan_expr to read cannot not X items False bank four beginning follows applicable item min each indices works side scanning right summed relative used id out index avoid current state @@ -747,7 +750,7 @@ _IGNORED_WORDS = set(""" update_loop first_seg a b prev_item i prev_item_unavailable_with_local_update prev_value - N NO_SEG_BOUNDARY + N NO_SEG_BOUNDARY across_seg_boundary """.split()) def _make_template(s): @@ -807,6 +810,21 @@ class _GenericScanKernelBase(object): modified by the scan, it should live in `b`. This expression may call functions given in the *preamble*. + + Another value available to this expression is `across_seg_boundary`, + a C `bool` indicating whether this scan update is crossing a + segment boundary, as defined by `is_segment_start_expr`. + The scan routine does not implement segmentation + semantics on its own. It relies on `scan_expr` to do this. + This value is available (but always `false`) even for a + non-segmented scan. + + .. note:: + + In early pre-releases of the segmented scan, + segmentation semantics were implemented *without* + relying on `scan_expr`. + :arg input_expr: A C expression, encoded as a string, resulting in the values to which the scan is applied. This may be used to apply a mapping to values stored in *arguments* before being @@ -1230,13 +1248,16 @@ void ${name_prefix}_debug_scan( prev_item = item; %if is_segmented: - { - bool is_seg_start = IS_SEG_START(i, my_val); - if (is_seg_start) - prev_item = ${neutral}; - } + bool is_seg_start = IS_SEG_START(i, my_val); %endif - item = SCAN_EXPR(prev_item, my_val); + + item = SCAN_EXPR(prev_item, my_val, + %if is_segmented: + is_seg_start + %else: + false + %endif + ); { ${output_statement}; diff --git a/test/test_array.py b/test/test_array.py index bef04fb9..694f5a51 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -957,7 +957,7 @@ def test_segmented_scan(ctx_factory): arguments="__global %s *ary, __global char *segflags, __global %s *out" % (ctype, ctype), input_expr="ary[i]", - scan_expr="a+b", neutral="0", + scan_expr="across_seg_boundary ? b : (a+b)", neutral="0", is_segment_start_expr="segflags[i]", output_statement=output_statement, options=[]) @@ -1117,6 +1117,9 @@ def test_view(ctx_factory): @pytools.test.mark_test.opencl def no_test_slice(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + from pyopencl.clrandom import rand as clrand l = 20000 -- GitLab