diff --git a/pyopencl/scan.py b/pyopencl/scan.py index ae2549984ec39bb7d4be701dd765e377f1712c91..6c144961794928f22987097e1f13b091e9966a8f 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -143,9 +143,25 @@ void ${name_prefix}_scan_intervals( %endif ) { - // padded in WG_SIZE to avoid bank conflicts // index K in first dimension used for carry storage - LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1]; + %if scan_dtype.itemsize > 4 and scan_dtype.itemsize % 8 == 0: + // Avoid bank conflicts by adding a single 32-bit value to the size of + // the scan type. + struct __attribute__ ((__packed__)) wrapped_scan_type + { + scan_type value; + int dummy; + }; + LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1]; + %else: + struct wrapped_scan_type + { + scan_type value; + }; + + // padded in WG_SIZE to avoid bank conflicts + LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE]; + %endif %if is_segmented: LOCAL_MEM char l_segment_start_flags[K][WG_SIZE]; @@ -253,7 +269,7 @@ void ${name_prefix}_scan_intervals( const index_type o_mod_k = offset % K; const index_type o_div_k = offset / K; - ldata[o_mod_k][offset / K] = scan_value; + ldata[o_mod_k][offset / K].value = scan_value; %if is_segmented: bool is_seg_start = IS_SEG_START(read_i, scan_value); @@ -281,7 +297,7 @@ void ${name_prefix}_scan_intervals( if (LID_0 == 0 && unit_base != interval_begin) { - ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0], + ldata[0][0].value = SCAN_EXPR(ldata[K][WG_SIZE - 1].value, ldata[0][0].value, %if is_segmented: (l_segment_start_flags[0][0]) %else: @@ -298,7 +314,7 @@ void ${name_prefix}_scan_intervals( // {{{ scan along k (sequentially in each work item) - scan_type sum = ldata[0][LID_0]; + scan_type sum = ldata[0][LID_0].value; %if is_tail: const index_type offset_end = interval_end - unit_base; @@ -310,7 +326,7 @@ void ${name_prefix}_scan_intervals( if (K * LID_0 + k < offset_end) %endif { - scan_type tmp = ldata[k][LID_0]; + scan_type tmp = ldata[k][LID_0].value; index_type seq_i = unit_base + K*LID_0 + k; %if is_segmented: @@ -330,7 +346,7 @@ void ${name_prefix}_scan_intervals( %endif ); - ldata[k][LID_0] = sum; + ldata[k][LID_0].value = sum; } } @@ -339,7 +355,7 @@ void ${name_prefix}_scan_intervals( // }}} // store carry in out-of-bounds (padding) array entry (index K) in the K direction - ldata[K][LID_0] = sum; + ldata[K][LID_0].value = sum; %if is_segmented: l_first_segment_start_in_subtree[LID_0] = first_segment_start_in_k_group; @@ -361,7 +377,7 @@ void ${name_prefix}_scan_intervals( // across k groups, along local id // (uses out-of-bounds k=K array entry for storage) - scan_type val = ldata[K][LID_0]; + scan_type val = ldata[K][LID_0].value; <% scan_offset = 1 %> @@ -370,7 +386,7 @@ void ${name_prefix}_scan_intervals( if (LID_0 >= ${scan_offset}) { - scan_type tmp = ldata[K][LID_0 - ${scan_offset}]; + scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value; % if is_tail: if (K*LID_0 < offset_end) % endif @@ -409,7 +425,7 @@ void ${name_prefix}_scan_intervals( // {{{ writes to local allowed, reads from local not allowed - ldata[K][LID_0] = val; + ldata[K][LID_0].value = val; %if is_segmented: l_first_segment_start_in_subtree[LID_0] = first_segment_start_in_subtree; @@ -445,7 +461,7 @@ void ${name_prefix}_scan_intervals( if (LID_0 > 0) { - sum = ldata[K][LID_0 - 1]; + sum = ldata[K][LID_0 - 1].value; for(index_type k = 0; k < K; k++) { @@ -453,8 +469,8 @@ void ${name_prefix}_scan_intervals( if (K * LID_0 + k < offset_end) %endif { - scan_type tmp = ldata[k][LID_0]; - ldata[k][LID_0] = SCAN_EXPR(sum, tmp, + scan_type tmp = ldata[k][LID_0].value; + ldata[k][LID_0].value = SCAN_EXPR(sum, tmp, %if is_segmented: (unit_base + K * LID_0 + k >= first_segment_start_in_k_group) @@ -517,7 +533,7 @@ void ${name_prefix}_scan_intervals( index_type remainder = linear_index - linear_scan_data_idx * scan_types_per_int; __local int *src = (__local int *) &( - ldata[linear_scan_data_idx % K][linear_scan_data_idx / K]); + ldata[linear_scan_data_idx % K][linear_scan_data_idx / K].value); dest[linear_index] = src[remainder]; } @@ -1190,9 +1206,9 @@ class GenericScanKernel(_GenericScanKernelBase): store_segment_start_flags, k_group_size): scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments) - # Thrust says that 128 is big enough for GT200 + # Empirically found on Nv hardware: no need to be bigger than this size wg_size = _round_down_to_power_of_2( - min(max_wg_size, 128)) + min(max_wg_size, 256)) scan_tpl = _make_template(SCAN_INTERVALS_SOURCE) scan_src = str(scan_tpl.render(