From 539c1f990870091efd51597e9cba2c92e2a72a3f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 23 Jul 2012 13:29:37 -0500 Subject: [PATCH] Segmented scan passes all tests. --- pyopencl/scan.py | 134 +++++++++++++++++++++++---------------------- test/test_array.py | 16 +++--- 2 files changed, 79 insertions(+), 71 deletions(-) diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 0aaabb91..136c789c 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -45,6 +45,8 @@ from pyopencl.tools import context_dependent_memoize # {{{ preamble SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL// +#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable + #define WG_SIZE ${wg_size} /* SCAN_EXPR has no right know the indices it is scanning at because @@ -135,7 +137,7 @@ void ${name_prefix}_scan_intervals( , GLOBAL_MEM index_type *g_first_segment_start_in_interval %endif %if store_segment_start_flags: - , GLOBAL_MEM index_type *g_segment_start_flags + , GLOBAL_MEM char *g_segment_start_flags %endif ) { @@ -144,7 +146,7 @@ void ${name_prefix}_scan_intervals( LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1]; %if is_segmented: - LOCAL_MEM bool l_segment_start_flags[K][WG_SIZE]; + LOCAL_MEM char l_segment_start_flags[K][WG_SIZE]; LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE]; // only relevant/populated for local id 0 @@ -513,7 +515,7 @@ void ${name_prefix}_final_update( , GLOBAL_MEM index_type *g_first_segment_start_in_interval %endif %if is_segmented and use_lookbehind_update: - , GLOBAL_MEM bool *g_segment_start_flags + , GLOBAL_MEM char *g_segment_start_flags %endif ) { @@ -521,7 +523,7 @@ void ${name_prefix}_final_update( LOCAL_MEM scan_type ldata[WG_SIZE]; %endif %if is_segmented and use_lookbehind_update: - LOCAL_MEM scan_type l_segment_start_flags[WG_SIZE]; + LOCAL_MEM char l_segment_start_flags[WG_SIZE]; %endif const index_type interval_begin = interval_size * GID_0; @@ -532,12 +534,18 @@ void ${name_prefix}_final_update( if (GID_0 != 0) carry = interval_results[GID_0 - 1]; + %if is_segmented: + const index_type first_seg_start_in_interval = + g_first_segment_start_in_interval[GID_0]; + + %endif + %if not use_lookbehind_update: // {{{ no look-behind ('prev_item' not in output_statement -> simpler) index_type update_i = interval_begin+LID_0; - <%def name="update_loop_plain(end, phase)"> + <%def name="update_loop_plain(end)"> for(; update_i < ${end}; update_i += WG_SIZE) { scan_type partial_val = partial_scan_buffer[update_i]; @@ -547,7 +555,17 @@ void ${name_prefix}_final_update( } </%def> - <% update_loop = self.update_loop_plain %> + // {{{ update to the first intra-interval segment boundary + + %if is_segmented: + index_type seg_end = min(first_seg_start_in_interval, interval_end); + ${update_loop_plain("seg_end")} + carry = ${neutral}; + %endif + + // }}} + + ${update_loop_plain("interval_end")} // }}} %else: @@ -560,71 +578,56 @@ void ${name_prefix}_final_update( index_type group_base = interval_begin; scan_type prev_value = carry; // (A) - <%def name="update_loop_lookbehind(end, phase)"> - for(; group_base < ${end}; group_base += WG_SIZE) - { - index_type update_i = group_base+LID_0; - - // load a work group's worth of data - if (update_i < ${end}) - { - scan_type tmp = partial_scan_buffer[update_i]; - ldata[LID_0] = SCAN_EXPR(carry, tmp); - %if is_segmented: - l_segment_start_flags[LID_0] = g_segment_start_flags[update_i]; - %endif - } - - local_barrier(); + for(; group_base < interval_end; group_base += WG_SIZE) + { + index_type update_i = group_base+LID_0; - // find prev_value - if (LID_0 != 0) - prev_value = ldata[LID_0 - 1]; - /* - else - prev_value = carry (see (A)) OR last tail (see (B)); - */ + // load a work group's worth of data + if (update_i < interval_end) + { + scan_type tmp = partial_scan_buffer[update_i]; - if (update_i < ${end}) - { - %if is_segmented: - if (l_segment_start_flags[LID_0]) - prev_value = ${neutral}; - %endif + %if is_segmented: + if (update_i < first_seg_start_in_interval) + %endif + { tmp = SCAN_EXPR(carry, tmp); } - scan_type value = ldata[LID_0]; - OUTPUT_STMT(update_i, prev_value, value) - } + ldata[LID_0] = tmp; - if (LID_0 == 0) - prev_value = ldata[WG_SIZE - 1]; // (B) - - local_barrier(); + %if is_segmented: + l_segment_start_flags[LID_0] = g_segment_start_flags[update_i]; + %endif } - </%def> - // FIXME TAKE CARE OF BLOCK RESYNCING + local_barrier(); - <% update_loop = self.update_loop_lookbehind %> + // find prev_value + if (LID_0 != 0) + prev_value = ldata[LID_0 - 1]; + /* + else + prev_value = carry (see (A)) OR last tail (see (B)); + */ - // }}} - %endif - - %if is_segmented: - // {{{ update to the first intra-interval segment boundary + if (update_i < interval_end) + { + %if is_segmented: + if (l_segment_start_flags[LID_0]) + prev_value = ${neutral}; + %endif - const index_type first_seg_start_in_interval = - g_first_segment_start_in_interval[GID_0]; + scan_type value = ldata[LID_0]; + OUTPUT_STMT(update_i, prev_value, value) + } - index_type seg_end = min(first_seg_start_in_interval, interval_end); - ${update_loop("seg_end", phase="first_seg")} + if (LID_0 == 0) + prev_value = ldata[WG_SIZE - 1]; // (B) - carry = ${neutral}; + local_barrier(); + } // }}} %endif - - ${update_loop("interval_end", phase="remainder")} } """ @@ -674,24 +677,26 @@ _PREFIX_WORDS = set(""" l_ o_mod_k o_div_k l_segment_start_flags scan_value sum first_seg_start_in_interval g_segment_start_flags group_base seg_end + + LID_2 LID_1 LID_0 + LDIM_0 LDIM_1 LDIM_2 + GDIM_0 GDIM_1 GDIM_2 + GID_0 GID_1 GID_2 """.split()) _IGNORED_WORDS = set(""" typedef for endfor if void while endwhile endfor endif else const printf - None return bool n + None return bool n char set iteritems len setdefault - LID_2 LID_1 LID_0 - LDIM_0 LDIM_1 LDIM_2 - GDIM_0 GDIM_1 GDIM_2 - GID_0 GID_1 GID_2 GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE local_barrier CLK_LOCAL_MEM_FENCE OPENCL EXTENSION pragma __attribute__ __global __kernel __local get_local_size get_local_id cl_khr_fp64 reqd_work_group_size get_num_groups barrier get_group_id + cl_khr_byte_addressable_store _final_update _scan_intervals @@ -961,7 +966,8 @@ class GenericScanKernel(object): use_lookbehind_update=use_lookbehind_update, **self.code_variables)) - with open("update.cl", "wt") as f: f.write(final_update_src) + #with open("update.cl", "wt") as f: f.write(final_update_src) + final_update_prg = cl.Program(self.context, final_update_src).build(options) self.final_update_knl = getattr( final_update_prg, @@ -1011,7 +1017,7 @@ class GenericScanKernel(object): store_segment_start_flags=store_segment_start_flags, **self.code_variables)) - with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src) + #with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src) prg = cl.Program(self.context, scan_src).build(self.options) diff --git a/test/test_array.py b/test/test_array.py index 247ac996..18768689 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -807,7 +807,7 @@ def test_scan(ctx_factory): if 1 and not is_ok: print(summarize_error(dev_data.get(), desired_result, host_data)) - print n, is_ok + print(n, is_ok) assert is_ok from gc import collect collect() @@ -882,8 +882,8 @@ def test_segmented_scan(ctx_factory): dtype = np.int32 ctype = dtype_to_ctype(dtype) - for is_exclusive in [False, True]: - #for is_exclusive in [True, False]: + #for is_exclusive in [False, True]: + for is_exclusive in [True, False]: if is_exclusive: output_statement = "out[i] = prev_item" else: @@ -924,6 +924,7 @@ def test_segmented_scan(ctx_factory): for seg_boundaries in seg_boundaries_values: #print "BOUNDARIES", seg_boundaries + #print a seg_boundary_flags = np.zeros(n, dtype=np.uint8) seg_boundary_flags[seg_boundaries] = 1 @@ -940,11 +941,11 @@ def test_segmented_scan(ctx_factory): if is_exclusive: result_host[seg_start+1:seg_end] = np.cumsum( - result_host[seg_start:seg_end][:-1]) + a[seg_start:seg_end][:-1]) result_host[seg_start] = 0 else: result_host[seg_start:seg_end] = np.cumsum( - result_host[seg_start:seg_end]) + a[seg_start:seg_end]) #print "REF", result_host @@ -955,12 +956,13 @@ def test_segmented_scan(ctx_factory): is_correct = (result_dev.get() == result_host).all() if not is_correct: diff = result_dev.get() - result_host - print "RES-REF", diff + print("RES-REF", diff) print("ERRWHERE", np.where(diff)) print(n, list(seg_boundaries)) assert is_correct - print(n, "done") + + print("%d excl:%s done" % (n, is_exclusive)) -- GitLab