diff --git a/pyopencl/scan.py b/pyopencl/scan.py index c61523f029d0ac62a13c3d9ed4937ac084996351..5ac759f5cfa45b08aaa28b52bac3a24448c9ee98 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -500,8 +500,6 @@ void ${name_prefix}_scan_intervals( UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL// -#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; } - KERNEL REQD_WG_SIZE(WG_SIZE, 1, 1) void ${name_prefix}_final_update( @@ -548,9 +546,10 @@ void ${name_prefix}_final_update( for(; update_i < ${end}; update_i += WG_SIZE) { scan_type partial_val = partial_scan_buffer[update_i]; - scan_type value = SCAN_EXPR(carry, partial_val); + scan_type item = SCAN_EXPR(carry, partial_val); + index_type i = update_i; - OUTPUT_STMT(update_i, prev_item_unavailable_with_local_update, value); + { ${output_statement}; } } </%def> @@ -575,7 +574,7 @@ void ${name_prefix}_final_update( // and there are lots of local ifs. index_type group_base = interval_begin; - scan_type prev_value = carry; // (A) + scan_type prev_item = carry; // (A) for(; group_base < interval_end; group_base += WG_SIZE) { @@ -600,27 +599,28 @@ void ${name_prefix}_final_update( local_barrier(); - // find prev_value + // find prev_item if (LID_0 != 0) - prev_value = ldata[LID_0 - 1]; + prev_item = ldata[LID_0 - 1]; /* else - prev_value = carry (see (A)) OR last tail (see (B)); + prev_item = carry (see (A)) OR last tail (see (B)); */ if (update_i < interval_end) { %if is_segmented: if (l_segment_start_flags[LID_0]) - prev_value = ${neutral}; + prev_item = ${neutral}; %endif - scan_type value = ldata[LID_0]; - OUTPUT_STMT(update_i, prev_value, value) + scan_type item = ldata[LID_0]; + index_type i = update_i; + { ${output_statement}; } } if (LID_0 == 0) - prev_value = ldata[WG_SIZE - 1]; // (B) + prev_item = ldata[WG_SIZE - 1]; // (B) local_barrier(); } @@ -681,7 +681,7 @@ _PREFIX_WORDS = set(""" segment_start_in_subtree offset interval_results interval_end first_segment_start_in_subtree unit_base first_segment_start_in_interval k INPUT_EXPR - prev_group_sum prev pv value partial_val pgs OUTPUT_STMT + prev_group_sum prev pv value partial_val pgs is_seg_start update_i scan_item_at_i seq_i read_i l_ o_mod_k o_div_k l_segment_start_flags scan_value sum first_seg_start_in_interval g_segment_start_flags @@ -1196,8 +1196,6 @@ class GenericScanKernel(_GenericScanKernelBase): DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + """//CL// -#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; } - KERNEL REQD_WG_SIZE(1, 1, 1) void ${name_prefix}_debug_scan( @@ -1232,7 +1230,7 @@ void ${name_prefix}_debug_scan( item = SCAN_EXPR(prev_item, my_val); { - OUTPUT_STMT(i, prev_item, item); + ${output_statement}; } } } @@ -1242,8 +1240,8 @@ class GenericDebugScanKernel(_GenericScanKernelBase): def finish_setup(self): scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE) scan_src = str(scan_tpl.render( - output_statement=_process_code_for_macro(self.output_statement), - argument_signature=_process_code_for_macro(self.arguments), + output_statement=self.output_statement, + argument_signature=", ".join(arg.declarator() for arg in self.parsed_args), is_segment_start_expr=self.is_segment_start_expr, input_expr=_process_code_for_macro(self.input_expr), input_fetch_exprs=self.input_fetch_exprs,