diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 75e4a83e0dd904859b9d2633b9049dd5356598ef..7562f485cfb2809ad9dc1fc4cdbec45c269c97d8 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -36,11 +36,14 @@ import pyopencl.array as cl_array from pyopencl.tools import dtype_to_ctype, bitlog2 import pyopencl._mymako as mako from pyopencl._cluda import CLUDA_PREAMBLE +from pyopencl.tools import context_dependent_memoize +# {{{ preamble + SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL// #define WG_SIZE ${wg_size} @@ -56,18 +59,59 @@ and thus index-based side computations cannot be meaningful. */ ${preamble} -typedef ${scan_type} scan_type; +typedef ${scan_ctype} scan_type; typedef ${index_ctype} index_type; #define NO_SEG_BOUNDARY ${index_type_max} - """ - - +# }}} # {{{ main scan code -SCAN_INTERVALS_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL// +# Algorithm: Each work group is responsible for one contiguous +# 'interval'. There are just enough intervals to fill all compute +# units. Intervals are split into 'units'. A unit is what gets +# worked on in parallel by one work group. +# +# in index space: +# interval > unit > local-parallel > k-group +# +# (Note that there is also a transpose in here: The data is read +# with local ids along linear index order.) +# +# Each unit has two axes--the local-id axis and the k axis. +# +# unit 0: +# | | | | | | | | | | ----> lid +# | | | | | | | | | | +# | | | | | | | | | | +# | | | | | | | | | | +# | | | | | | | | | | +# +# | +# v k (fastest-moving in linear index) +# +# unit 1: +# | | | | | | | | | | ----> lid +# | | | | | | | | | | +# | | | | | | | | | | +# | | | | | | | | | | +# | | | | | | | | | | +# +# | +# v k (fastest-moving in linear index) +# +# ... +# +# At a device-global level, this is a three-phase algorithm, in +# which first each interval does its local scan, then a scan +# across intervals exchanges data globally, and the final update +# adds the exchanged sums to each interval. 
+# +# Exclusive scan is realized by performing a right-shift inside +# the final update. + +SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + """//CL// #define K ${k_group_size} @@ -119,48 +163,6 @@ void ${name_prefix}_scan_intervals( %endif { - // Algorithm: Each work group is responsible for one contiguous - // 'interval'. There are just enough intervals to fill all compute - // units. Intervals are split into 'units'. A unit is what gets - // worked on in parallel by one work group. - - // in index space: - // interval > unit > local-parallel > k-group - - // (Note that there is also a transpose in here: The data is read - // with local ids along linear index order.) - - // Each unit has two axes--the local-id axis and the k axis. - // - // unit 0: - // | | | | | | | | | | ----> lid - // | | | | | | | | | | - // | | | | | | | | | | - // | | | | | | | | | | - // | | | | | | | | | | - // - // | - // v k (fastest-moving in linear index) - - // unit 1: - // | | | | | | | | | | ----> lid - // | | | | | | | | | | - // | | | | | | | | | | - // | | | | | | | | | | - // | | | | | | | | | | - // - // | - // v k (fastest-moving in linear index) - // - // ... - - // At a device-global level, this is a three-phase algorithm, in - // which first each interval does its local scan, then a scan - // across intervals exchanges data globally, and the final update - // adds the exchanged sums to each interval. - - // Exclusive scan is realized by performing a right-shift inside - // the final update. 
// {{{ read a unit's worth of data from global @@ -168,13 +170,13 @@ void ${name_prefix}_scan_intervals( { const index_type offset = k*WG_SIZE + LID_0; - const index_type i = unit_base + offset; + const index_type read_i = unit_base + offset; %if is_tail: - if (i < interval_end) + if (read_i < interval_end) %endif { - ldata[offset % K][offset / K] = INPUT_EXPR(i); + ldata[offset % K][offset / K] = INPUT_EXPR(read_i); } } @@ -219,12 +221,12 @@ void ${name_prefix}_scan_intervals( %endif { scan_type tmp = ldata[k][LID_0]; - index_type i = unit_base + K*LID_0 + k; + index_type seq_i = unit_base + K*LID_0 + k; %if is_segmented: - if (IS_SEG_START(i, tmp) + if (IS_SEG_START(seq_i, tmp)) { - first_segment_start_in_k_group = i; + first_segment_start_in_k_group = seq_i; sum = tmp; } else @@ -250,9 +252,9 @@ void ${name_prefix}_scan_intervals( // This tree-based scan works as follows: // - Each work item adds the previous item to its current state - // - barrier sync + // - barrier // - Each work item adds in the item from two positions to the left - // - barrier sync + // - barrier // - Each work item adds in the item from four positions to the left // ... // At the end, each item has summed all prior items. 
@@ -392,15 +394,17 @@ void ${name_prefix}_scan_intervals( %endif } } -""", strict_undefined=True, disable_unicode=True) +""" # }}} -# {{{ inclusive update +# {{{ local update + +# used for inclusive scan -INCLUSIVE_UPDATE_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL// +LOCAL_UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL// -#define OUTPUT_STMT(i, a) ${output_statement} +#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; } KERNEL REQD_WG_SIZE(WG_SIZE, 1, 1) @@ -415,8 +419,10 @@ void ${name_prefix}_final_update( %endif ) { - if (GID_0 == 0) - return; + %if neutral is None: + if (GID_0 == 0) + return; + %endif const index_type interval_begin = interval_size * GID_0; const index_type interval_end = min(interval_begin + interval_size, N); @@ -426,31 +432,42 @@ void ${name_prefix}_final_update( %endif // value to add to this segment - scan_type prev_group_sum = interval_results[GID_0 - 1]; + scan_type prev_group_sum; + %if neutral is not None: + if (GID_0 == 0) + prev_group_sum = ${neutral}; + else + %endif + prev_group_sum = interval_results[GID_0 - 1]; for(index_type unit_base = interval_begin; unit_base < interval_end; unit_base += WG_SIZE) { - const index_type i = unit_base + LID_0; + const index_type update_i = unit_base + LID_0; - if(i < interval_end) + if(update_i < interval_end) { - scan_type val = partial_scan_buffer[i]; - scan_type value = SCAN_EXPR(prev_group_sum, val); - OUTPUT_STMT(i, value) + scan_type partial_val = partial_scan_buffer[update_i]; + scan_type value = SCAN_EXPR(prev_group_sum, partial_val); + + // printf("i: %d pgs: %d pv: %d val: %d\n", update_i, prev_group_sum, partial_val, value); + + OUTPUT_STMT(update_i, prev_item_unavailable_with_local_update, value); } } } -""", strict_undefined=True, disable_unicode=True) +""" # }}} -# {{{ exclusive update +# {{{ lookbehind update + +# used for exclusive scan or output_statements that request look-behind -EXCLUSIVE_UPDATE_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL// +LOOKBEHIND_UPDATE_SOURCE = 
SHARED_PREAMBLE + """//CL// -#define OUTPUT_STMT(i, a) ${output_statement} +#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; } KERNEL REQD_WG_SIZE(WG_SIZE, 1, 1) @@ -475,18 +490,18 @@ void ${name_prefix}_final_update( if(GID_0 != 0) carry = interval_results[GID_0 - 1]; - scan_type value = carry; // (A) + scan_type prev_value = carry; // (A) for (index_type unit_base = interval_begin; unit_base < interval_end; unit_base += WG_SIZE) { - const index_type i = unit_base + LID_0; + const index_type update_i = unit_base + LID_0; // load a work group's worth of data - if (i < interval_end) + if (update_i < interval_end) { - scan_type tmp = partial_scan_buffer[i]; + scan_type tmp = partial_scan_buffer[update_i]; ldata[LID_0] = SCAN_EXPR(carry, tmp); } @@ -494,36 +509,42 @@ void ${name_prefix}_final_update( // perform right shift if (LID_0 != 0) - value = ldata[LID_0 - 1]; + prev_value = ldata[LID_0 - 1]; /* - else - value = carry (see (A)) OR last tail (see (B)); + else + prev_value = carry (see (A)) OR last tail (see (B)); */ %if is_segmented: { - scan_type scan_item_at_i = INPUT_EXPR(i) - bool is_seg_start = IS_SEG_START(i, scan_item_at_i); + scan_type scan_item_at_i = INPUT_EXPR(update_i); + bool is_seg_start = IS_SEG_START(update_i, scan_item_at_i); if (is_seg_start) - value = ${neutral}; + prev_value = ${neutral}; } %endif - if (i < interval_end) + if (update_i < interval_end) { - OUTPUT_STMT(i, value) + scan_type value = ldata[LID_0]; + + OUTPUT_STMT(update_i, prev_value, value) } if(LID_0 == 0) - value = ldata[WG_SIZE - 1]; // (B) + prev_value = ldata[WG_SIZE - 1]; // (B) local_barrier(); } } -""", strict_undefined=True, disable_unicode=True) +""" # }}} +# {{{ driver + +# {{{ helpers + def _round_down_to_power_of_2(val): result = 2**bitlog2(val) if result > val: @@ -532,14 +553,6 @@ def _round_down_to_power_of_2(val): assert result <= val return result - - - - -# {{{ driver - -# {{{ helpers - def _parse_args(arguments): from pyopencl.tools import 
parse_c_arg return [parse_c_arg(arg) for arg in arguments.split(",")] @@ -556,8 +569,83 @@ def _get_scalar_arg_dtypes(arg_types): return result +_PREFIX_WORDS = set(""" + ldata partial_scan_buffer global scan_offset + segment_start_in_k_group carry + g_first_segment_start_in_interval IS_SEG_START tmp Z + val l_first_segment_start_in_k_group unit_size + index_type interval_begin interval_size offset_end K + SCAN_EXPR do_update NO_SEG_BOUNDARY WG_SIZE + first_segment_start_in_k_group scan_type + segment_start_in_subtree offset interval_results interval_end + first_segment_start_in_subtree unit_base + first_segment_start_in_interval k INPUT_EXPR + prev_group_sum prev pv this add value n partial_val pgs OUTPUT_STMT + is_seg_start update_i scan_item_at_i seq_i read_i + """.split()) + +_IGNORED_WORDS = set(""" + typedef for endfor if void while endwhile endfor endif else const printf + None return bool + LID_2 LID_1 LID_0 + LDIM_0 LDIM_1 LDIM_2 + GDIM_0 GDIM_1 GDIM_2 + GID_0 GID_1 GID_2 + GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE + local_barrier + CLK_LOCAL_MEM_FENCE OPENCL EXTENSION + pragma __attribute__ __global __kernel __local + get_local_size get_local_id cl_khr_fp64 reqd_work_group_size + get_num_groups barrier get_group_id + + _final_update _scan_intervals + + positions all padded integer its previous write based true writes 0 + has local worth scan_expr to read cannot not X items False bank + four beginning follows applicable sum item min each indices works side + scanning right summed relative used id out index avoid current state + boundary True across be This reads groups along Otherwise undetermined + store of times prior s update first regardless Each number because + array unit from segment conflicts two parallel 2 empty define direction + CL padding work tree bounds values and adds + scan is allowed thus it an as enable at in occur sequentially end no + storage data 1 largest may representable uses entry Y meaningful + 
computations interval At the left dimension know d + A load B group perform shift tail see last OR + + is_tail is_first_level input_expr argument_signature preamble + double_support neutral output_statement index_type_max + k_group_size name_prefix is_segmented index_ctype scan_ctype + wg_size is_i_segment_start_expr + + a b prev_item i prev_item_unavailable_with_local_update prev_value + N + """.split()) + +def _make_template(s): + leftovers = set() + + def replace_id(match): + # avoid name clashes with user code by adding 'psc_' prefix to + # identifiers. + + word = match.group(1) + if word in _IGNORED_WORDS: + return word + elif word in _PREFIX_WORDS: + return "psc_"+word + else: + leftovers.add(word) + return word + + import re + s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s) + if leftovers: + from warnings import warn + warn("leftover words in identifier prefixing: " + " ".join(leftovers)) + return mako.template.Template(s, strict_undefined=True, disable_unicode=True) from pytools import Record class _ScanKernelInfo(Record): @@ -565,10 +653,10 @@ class _ScanKernelInfo(Record): # }}} -class _GenericScanKernelBase(object): +class GenericScanKernel(object): def __init__(self, ctx, dtype, - arguments, scan_expr, input_expr, output_statement, - neutral=None, is_i_segment_start_expr=None, + arguments, input_expr, scan_expr, neutral, output_statement, + is_i_segment_start_expr=None, partial_scan_buffer_name=None, name_prefix="scan", options=[], preamble="", devices=None): """ @@ -585,8 +673,11 @@ class _GenericScanKernelBase(object): to each array entry when scan first touches it. *arguments* must be given if *input_expr* is given. :arg output_statement: a C statement that writes - the output of the scan. It has access to the scan result as `a` - and the current index as `i`. + the output of the scan. It has access to the scan result as *item*, + the preceding scan result item as *prev_item*, and the current index + as *i*. 
*prev_item* allows building an exclusive scan. + *prev_item* in a segmented scan will be the neutral element + at a segment boundary, not the immediately preceding item. :arg is_i_segment_start_expr: If given, makes the scan a segmented scan. Has access to the current index `i` and the input element as `a` and returns a bool. If it returns true, then previous @@ -594,6 +685,9 @@ class _GenericScanKernelBase(object): The first array in the argument list determines the size of the index space over which the scan is carried out. + + All code fragments further have access to N, the number of elements + being processed in the scan. """ if isinstance(self, ExclusiveScanKernel) and neutral is None: @@ -601,7 +695,13 @@ class _GenericScanKernelBase(object): self.context = ctx dtype = self.dtype = np.dtype(dtype) - self.neutral = neutral + + if neutral is None: + from warnings import warn + warn("not specifying 'neutral' is deprecated and will lead to " + "wrong results if your scan is not in-place or your " + "'output_statement' otherwise does something non-trivial", + stacklevel=2) self.index_dtype = np.dtype(np.uint32) @@ -618,8 +718,12 @@ class _GenericScanKernelBase(object): if isinstance(arg, VectorArg)][0] self.is_segmented = is_i_segment_start_expr is not None + if self.is_segmented: + is_i_segment_start_expr = is_i_segment_start_expr.replace("\n", " ") - if self.is_segmented and self.is_exclusive: + use_lookbehind_update = "prev_item" in output_statement + + if self.is_segmented and use_lookbehind_update: # The final update in segmented exclusive scan must be able to # reconstruct where the segment boundaries were, and therefore # can't overwrite any of the input. 
@@ -642,10 +746,10 @@ class _GenericScanKernelBase(object): name_prefix=name_prefix, index_ctype=dtype_to_ctype(self.index_dtype), index_type_max=str(np.iinfo(self.index_dtype).max), - scan_type=dtype_to_ctype(dtype), + scan_ctype=dtype_to_ctype(dtype), is_segmented=self.is_segmented, - scan_expr=scan_expr, - neutral=neutral, + scan_expr=scan_expr.replace("\n", " "), + neutral=(None if neutral is None else neutral.replace("\n", " ")), double_support=all( has_double_support(dev) for dev in devices), ) @@ -724,12 +828,18 @@ class _GenericScanKernelBase(object): self.update_wg_size = min(max_scan_wg_size, 256) - final_update_src = str(self.final_update_tp.render( + if use_lookbehind_update: + update_src = LOOKBEHIND_UPDATE_SOURCE + else: + update_src = LOCAL_UPDATE_SOURCE + + final_update_tpl = _make_template(update_src) + final_update_src = str(final_update_tpl.render( wg_size=self.update_wg_size, - output_statement=output_statement, - argument_signature=arguments, + output_statement=output_statement.replace("\n", " "), + argument_signature=arguments.replace("\n", " "), is_i_segment_start_expr=is_i_segment_start_expr, - input_expr=input_expr, + input_expr=input_expr.replace("\n", " "), **self.code_variables)) final_update_prg = cl.Program(self.context, final_update_src).build(options) @@ -759,11 +869,12 @@ class _GenericScanKernelBase(object): else: k_group_size = 8 - scan_intervals_src = str(SCAN_INTERVALS_SOURCE.render( + scan_tpl = _make_template(SCAN_INTERVALS_SOURCE) + scan_intervals_src = str(scan_tpl.render( wg_size=wg_size, input_expr=input_expr.replace("\n", " "), k_group_size=k_group_size, - argument_signature=arguments, + argument_signature=arguments.replace("\n", " "), is_i_segment_start_expr=is_i_segment_start_expr, is_first_level=is_first_level, **self.code_variables)) @@ -830,7 +941,7 @@ class _GenericScanKernelBase(object): interval_results = allocator(self.dtype.itemsize*num_intervals) if self.partial_scan_buffer_idx is None: - partial_scan_buffer = allocator(n) + partial_scan_buffer = 
allocator(n*self.dtype.itemsize) else: partial_scan_buffer = data_args[self.partial_scan_buffer_idx] @@ -846,6 +957,11 @@ class _GenericScanKernelBase(object): queue, (num_intervals,), (l1_info.wg_size,), *scan1_args, **dict(g_times_l=True)) + if 0: + psb_host = np.empty(n, self.dtype) + cl.enqueue_copy(queue, psb_host, partial_scan_buffer) + print "PSB", psb_host + # }}} # {{{ second level inclusive scan of per-interval results @@ -878,29 +994,21 @@ class _GenericScanKernelBase(object): # }}} +# {{{ compatibility interface - -class GenericInclusiveScanKernel(_GenericScanKernelBase): - final_update_tp = INCLUSIVE_UPDATE_SOURCE - is_exclusive = False - -class GenericExclusiveScanKernel(_GenericScanKernelBase): - final_update_tp = EXCLUSIVE_UPDATE_SOURCE - is_exclusive = True - -class _ScanKernelBase(_GenericScanKernelBase): +class _ScanKernelBase(GenericScanKernel): def __init__(self, ctx, dtype, scan_expr, neutral=None, name_prefix="scan", options=[], preamble="", devices=None): scan_ctype = dtype_to_ctype(dtype) - _GenericScanKernelBase.__init__(self, + GenericScanKernel.__init__(self, ctx, dtype, arguments="__global %s *input_ary, __global %s *output_ary" % ( scan_ctype, scan_ctype), - scan_expr=scan_expr, input_expr="input_ary[i]", - output_statement="output_ary[i] = a;", + scan_expr=scan_expr, neutral=neutral, + output_statement=self.ary_output_statement, partial_scan_buffer_name="output_ary", options=options, preamble=preamble, devices=devices) @@ -926,17 +1034,56 @@ class _ScanKernelBase(_GenericScanKernelBase): if not n: return output_ary - _GenericScanKernelBase.__call__(self, + GenericScanKernel.__call__(self, input_ary, output_ary, allocator=allocator, queue=queue) return output_ary class InclusiveScanKernel(_ScanKernelBase): - final_update_tp = INCLUSIVE_UPDATE_SOURCE - is_exclusive = False + ary_output_statement = "output_ary[i] = item;" class ExclusiveScanKernel(_ScanKernelBase): - final_update_tp = EXCLUSIVE_UPDATE_SOURCE - is_exclusive = True + 
ary_output_statement = "output_ary[i] = prev_item;" + +# }}} + +# {{{ higher-level trickery + +@context_dependent_memoize +def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype): + ctype = dtype_to_ctype(dtype) + return GenericScanKernel( + ctx, scan_dtype, + arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype), + input_expr="(%s) ? 1 : 0" % predicate, + scan_expr="a+b", neutral="0", + output_statement=""" + if (prev_item != item) out[item-1] = ary[i]; + if (i+1 == N) *count = item; + """ + ) + +def copy_if(ary, predicate, queue=None): + if len(ary) > np.iinfo(np.uint32).max: + scan_dtype = np.uint64 + else: + scan_dtype = np.uint32 + + knl = get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype) + out = cl_array.empty_like(ary) + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + knl(ary, out, count, queue=queue) + return out, count + +def remove_if(array, predicate, **kwargs): + pass + +def partition(array, predicate): + pass + +def unique_by_key(array, key="", **kwargs): + pass + +# }}} # vim: filetype=pyopencl:fdm=marker diff --git a/test/test_array.py b/test/test_array.py index 7b8bdf0b45d0f5f8d979e60fcdbe2ae14d223f86..c8ee6a003e27232a7f8b2ecd0158daf61c94bdd9 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -741,6 +741,17 @@ def summarize_error(obtained, desired, orig, thresh=1e-5): return " ".join(entries) +scan_test_counts = [ + 10, + 2 ** 10 - 5, + 2 ** 10, + 2 ** 10 + 5, + 2 ** 20 - 2 ** 18, + 2 ** 20 - 2 ** 18 + 5, + 2 ** 20 + 1, + 2 ** 20, 2 ** 24 + ] + @pytools.test.mark_test.opencl def test_scan(ctx_factory): context = ctx_factory() @@ -755,16 +766,7 @@ def test_scan(ctx_factory): ]: knl = cls(context, dtype, "a+b", "0") - for n in [ - 10, - 2 ** 10 - 5, - 2 ** 10, - 2 ** 10 + 5, - 2 ** 20 - 2 ** 18, - 2 ** 20 - 2 ** 18 + 5, - 2 ** 20 + 1, - 2 ** 20, 2 ** 24 - ]: + for n in scan_test_counts: host_data = np.random.randint(0, 10, n).astype(dtype) dev_data = 
cl_array.to_device(queue, host_data) @@ -778,7 +780,7 @@ def test_scan(ctx_factory): desired_result -= host_data is_ok = (dev_data.get() == desired_result).all() - if 0 and not is_ok: + if 1 and not is_ok: print(summarize_error(dev_data.get(), desired_result, host_data)) print n, is_ok @@ -786,6 +788,22 @@ def test_scan(ctx_factory): from gc import collect collect() +@pytools.test.mark_test.opencl +def test_copy_if(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.clrandom import rand as clrand + for n in scan_test_counts: + a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000) + a = a_dev.get() + + from pyopencl.scan import copy_if + + selected = a[a>300] + selected_dev, count_dev = copy_if(a_dev, "ary[i] > 300") + + assert (selected_dev.get()[:count_dev.get()] == selected).all() @pytools.test.mark_test.opencl def test_stride_preservation(ctx_factory):