diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 0db3a2187563e3df31d1eacb659711baebcedb10..35a23cfa3c2de6154188a491cd6a7849988703ee 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -138,12 +138,29 @@ void ${name_prefix}_scan_intervals(
     // index K in first dimension used for carry storage
     LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1];
 
+    // {{{ set up local data for input_fetch_exprs if any of them are stenciled
+
+    <%
+    fetch_expr_offsets = {}
+    for name, arg_name, ife_offset in input_fetch_exprs:
+        fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+    local_fetch_expr_args = set(
+        arg_name
+        for arg_name, ife_offsets in fetch_expr_offsets.iteritems()
+        if -1 in ife_offsets or len(ife_offsets) > 1)
+    %>
+
+    %for arg_name in local_fetch_expr_args:
+        LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
+    %endfor
+
+    // }}}
+
     %if is_segmented:
         index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;
 
         LOCAL_MEM index_type l_first_segment_start_in_k_group[WG_SIZE];
         index_type first_segment_start_in_k_group;
-
-        if (LID_0 == 0)
     %endif
 
@@ -164,25 +181,67 @@
     {
 
+        // {{{ carry out input_fetch_exprs
+        // (if there are ones that need to be fetched into local)
+
+        %if local_fetch_expr_args:
+            for(index_type k = 0; k < K; k++)
+            {
+                const index_type offset = k*WG_SIZE + LID_0;
+                const index_type read_i = unit_base + offset;
+
+                %for arg_name in local_fetch_expr_args:
+                    %if is_tail:
+                    if (read_i < interval_end)
+                    %endif
+                    {
+                        l_${arg_name}[offset] = ${arg_name}[read_i];
+                    }
+                %endfor
+            }
+
+            local_barrier();
+        %endif
+
+        // }}}
+
         // {{{ read a unit's worth of data from global
 
        for(index_type k = 0; k < K; k++)
        {
            const index_type offset = k*WG_SIZE + LID_0;
-
            const index_type read_i = unit_base + offset;
 
            %if is_tail:
            if (read_i < interval_end)
            %endif
            {
+                %for name, arg_name, ife_offset in input_fetch_exprs:
+                    ${arg_ctypes[arg_name]} ${name};
+
+                    %if arg_name in local_fetch_expr_args:
+                        if (offset + ${ife_offset} >= 0)
+                            ${name} = l_${arg_name}[offset + ${ife_offset}];
+                        else if (read_i + ${ife_offset} >= 0)
+                            ${name} = ${arg_name}[read_i + ${ife_offset}];
+                        /*
+                        else
+                            if out of bounds, name is left undefined */
+
+                    %else:
+                        // ${arg_name} gets fetched directly from global
+                        ${name} = ${arg_name}[read_i];
+
+                    %endif
+                %endfor
+
                ldata[offset % K][offset / K] = INPUT_EXPR(read_i);
            }
        }
 
        // }}}
 
-        // {{{ carry in from previous unit, if applicable.
+        // {{{ carry in from previous unit, if applicable
 
        %if is_segmented:
            if (LID_0 == 0 && unit_base != interval_begin)
@@ -461,7 +520,8 @@ void ${name_prefix}_final_update(
 
 # {{{ lookbehind update
 
-# used for exclusive scan or output_statements that request look-behind
+# used for exclusive scan or output_statements that request look-behind, i.e.
+# access to the preceding item.
 
 LOOKBEHIND_UPDATE_SOURCE = SHARED_PREAMBLE + """//CL//
 
@@ -580,13 +640,17 @@ _PREFIX_WORDS = set("""
     segment_start_in_subtree offset interval_results interval_end
     first_segment_start_in_subtree unit_base
     first_segment_start_in_interval k INPUT_EXPR
-    prev_group_sum prev pv this add value n partial_val pgs OUTPUT_STMT
+    prev_group_sum prev pv value n partial_val pgs OUTPUT_STMT
     is_seg_start update_i scan_item_at_i seq_i read_i
+    l_
     """.split())
 
 _IGNORED_WORDS = set("""
     typedef for endfor if void while endwhile endfor endif else const printf
     None return bool
+
+    set iteritems len setdefault
+
     LID_2 LID_1 LID_0 LDIM_0 LDIM_1 LDIM_2 GDIM_0 GDIM_1 GDIM_2
 
@@ -612,11 +676,16 @@ _IGNORED_WORDS = set("""
     storage data 1 largest may representable uses entry Y meaningful
     computations interval At the left dimension know d
     A load B group perform shift tail see last OR
+    this add fetched into are directly need
+    gets them stenciled that undefined
+    there up any ones or name
 
     is_tail is_first_level input_expr argument_signature preamble
     double_support neutral output_statement
     index_type_max k_group_size name_prefix is_segmented
     index_ctype scan_ctype
-    wg_size is_i_segment_start_expr
+    wg_size is_i_segment_start_expr fetch_expr_offsets
+    arg_ctypes ife_offsets input_fetch_exprs
+    ife_offset arg_name local_fetch_expr_args
 
     a b prev_item i prev_item_unavailable_with_local_update prev_value N
@@ -643,7 +712,7 @@ def _make_template(s):
 
     if leftovers:
         from warnings import warn
-        warn("leftover words in identifier prefixing:" + " ".join(leftovers))
+        warn("leftover words in identifier prefixing: " + " ".join(leftovers))
 
     return mako.template.Template(s, strict_undefined=True, disable_unicode=True)
 
@@ -656,8 +725,9 @@ class _ScanKernelInfo(Record):
 
 class GenericScanKernel(object):
     def __init__(self, ctx, dtype,
             arguments, input_expr, scan_expr, neutral, output_statement,
-            is_i_segment_start_expr=None,
+            is_i_segment_start_expr=None, input_fetch_exprs=[],
             partial_scan_buffer_name=None,
+            index_dtype=np.int32,
             name_prefix="scan", options=[], preamble="", devices=None):
         """
         :arg ctx: a :class:`pyopencl.Context` within which the code
@@ -685,6 +755,13 @@ class GenericScanKernel(object):
             scan. Has access to the current index `i` and the input element
             as `a` and returns a bool. If it returns true, then previous sums
             will not spill over into the item with index i.
+        :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
+            An entry here has the effect of doing the equivalent of the
+            following before input_expr::
+
+                typeof(ARG_NAME) NAME = ARG_NAME[i+OFFSET];
+
+            OFFSET is allowed to be 0 or -1.
 
         The first array in the argument list determines the size of the index
         space over which the scan is carried out.
@@ -706,7 +783,9 @@
                     "'output_statement' otherwise does something non-trivial",
                     stacklevel=2)
 
-        self.index_dtype = np.dtype(np.uint32)
+        self.index_dtype = np.dtype(index_dtype)
+        if np.iinfo(self.index_dtype).min >= 0:
+            raise TypeError("index_dtype must be signed")
 
         if devices is None:
             devices = ctx.devices
@@ -732,6 +811,22 @@
             # can't overwrite any of the input.
             partial_scan_buffer_name = None
 
+            if input_fetch_exprs:
+                # FIXME need to insert code to handle input_fetch_exprs into
+                # the lookbehind update
+                raise NotImplementedError("input_fetch_exprs are not supported "
+                        "with a segmented scan using a look-behind update "
an exclusive scan)") + + + for name, arg_name, ife_offset in input_fetch_exprs: + if ife_offset not in [0, -1]: + raise RuntimeError("input_fetch_expr offsets must either be 0 or -1") + + arg_ctypes = {} + for arg in self.parsed_args: + arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype) + if partial_scan_buffer_name is not None: self.partial_scan_buffer_idx, = [ i for i, arg in enumerate(self.parsed_args) @@ -751,6 +846,7 @@ class GenericScanKernel(object): index_type_max=str(np.iinfo(self.index_dtype).max), scan_ctype=dtype_to_ctype(dtype), is_segmented=self.is_segmented, + arg_ctypes=arg_ctypes, scan_expr=scan_expr.replace("\n", " "), neutral=neutral.replace("\n", " "), double_support=all( @@ -768,7 +864,9 @@ class GenericScanKernel(object): while True: candidate_scan_info = self.build_scan_kernel( max_scan_wg_size, arguments, input_expr, - is_i_segment_start_expr, is_first_level=True) + is_i_segment_start_expr, + input_fetch_exprs=input_fetch_exprs, + is_first_level=True) # Will this device actually let us execute this kernel # at the desired work group size? Building it is the @@ -795,7 +893,7 @@ class GenericScanKernel(object): # {{{ build second-level scan - second_level_arguments = [ + second_level_arguments = self.arguments.split(",") + [ "__global %s *interval_sums" % dtype_to_ctype(dtype)] second_level_build_kwargs = {} if self.is_segmented: @@ -816,6 +914,7 @@ class GenericScanKernel(object): max_scan_wg_size, arguments=", ".join(second_level_arguments), input_expr="interval_sums[i]", + input_fetch_exprs=[], is_first_level=False, **second_level_build_kwargs) @@ -856,7 +955,7 @@ class GenericScanKernel(object): # }}} def build_scan_kernel(self, max_wg_size, arguments, input_expr, - is_i_segment_start_expr, is_first_level): + is_i_segment_start_expr, input_fetch_exprs, is_first_level): scalar_arg_dtypes = _get_scalar_arg_dtypes(_parse_args(arguments)) # Thrust says that 128 is big enough for GT200 @@ -879,6 +978,7 @@ class GenericScanKernel(object): k_group_size=k_group_size, argument_signature=arguments.replace("\n", " "), is_i_segment_start_expr=is_i_segment_start_expr, + input_fetch_exprs=input_fetch_exprs, is_first_level=is_first_level, **self.code_variables)) @@ -972,10 +1072,10 @@ class GenericScanKernel(object): # can scan at most one interval assert interval_size >= num_intervals - scan2_args = (interval_results, interval_results) + scan2_args = data_args + [interval_results, interval_results] if self.is_segmented: scan2_args = scan2_args + [first_segment_start_in_interval] - scan2_args = scan2_args + (num_intervals, interval_size) + scan2_args = scan2_args + [num_intervals, interval_size] l2_info.kernel( queue, (1,), (l1_info.wg_size,), @@ -1053,7 +1153,8 @@ class ExclusiveScanKernel(_ScanKernelBase): # {{{ higher-level trickery @context_dependent_memoize -def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types): +def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, + extra_args_types, preamble): ctype = dtype_to_ctype(dtype) arguments = [ "__global %s *ary" % ctype, @@ -1071,10 +1172,10 @@ def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types): output_statement=""" if (prev_item != item) out[item-1] = ary[i]; if (i+1 == N) *count = item; - """ - ) + """, + preamble=preamble) -def copy_if(ary, predicate, extra_args=[], queue=None): +def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""): """ :arg extra_args: a list of tuples *(name, value)* specifying extra arguments to pass to the scan procedure. 
@@ -1087,17 +1188,20 @@
     extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types)
+    knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype,
+            extra_args_types, preamble=preamble)
     out = cl_array.empty_like(ary)
     count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
     knl(ary, out, count, *extra_args_values, queue=queue)
     return out, count
 
-def remove_if(ary, predicate, extra_args=[], queue=None):
-    return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, queue=queue)
+def remove_if(ary, predicate, extra_args=[], queue=None, preamble=""):
+    return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, queue=queue,
+            preamble=preamble)
 
 @context_dependent_memoize
-def _get_partition_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
+def _get_partition_kernel(ctx, dtype, predicate, scan_dtype,
+        extra_args_types, preamble):
     ctype = dtype_to_ctype(dtype)
     arguments = [
         "__global %s *ary" % ctype,
@@ -1119,10 +1223,10 @@
                 if (prev_item != item)
                     out_true[item-1] = ary[i];
                 else
                     out_false[i-item] = ary[i];
                 if (i+1 == N) *count_true = item;
-                """
-        )
+                """,
+            preamble=preamble)
 
-def partition(ary, predicate, extra_args=[], queue=None):
+def partition(ary, predicate, extra_args=[], queue=None, preamble=""):
     """
     :arg extra_args: a list of tuples *(name, value)* specifying extra
         arguments to pass to the scan procedure.
@@ -1135,7 +1239,8 @@
     extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types)
+    knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype,
+            extra_args_types, preamble)
     out_true = cl_array.empty_like(ary)
     out_false = cl_array.empty_like(ary)
     count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
@@ -1143,35 +1248,53 @@
     knl(ary, out_true, out_false, count, *extra_args_values, queue=queue)
     return out_true, out_false, count
 
 @context_dependent_memoize
-def _get_unique_by_key_kernel(ctx, dtype, key_expr, scan_dtype, extra_args_types):
+def _get_unique_by_key_kernel(ctx, dtype, key_expr, scan_dtype,
+        extra_args_types, preamble):
     ctype = dtype_to_ctype(dtype)
     arguments = [
         "__global %s *ary" % ctype,
         "__global %s *out" % ctype,
-        "__global unsigned long *count_true",
+        "__global unsigned long *count_unique",
         ] + [
         "%s %s" % (dtype_to_ctype(arg_dtype), name)
             for name, arg_dtype in extra_args_types]
 
+    key_expr_define = "#define KEY_EXPR(a) %s\n" % key_expr.replace("\n", " ")
     return GenericScanKernel(
             ctx, dtype,
             arguments=", ".join(arguments),
-            input_expr="(%s) ? 1 : 0" % key_expr,
+            input_fetch_exprs=[
+                ("ary_im1", "ary", -1),
+                ("ary_i", "ary", 0),
+                ],
1 : 0", scan_expr="a+b", neutral="0", output_statement=""" - if (prev_item != item) - out_true[item-1] = ary[i]; - else - out_false[i-item] = ary[i]; - if (i+1 == N) *count_true = item; - """) + if (prev_item != item) out[item-1] = ary[i]; + if (i+1 == N) *count_unique = item; + """, + preamble=preamble+"\n\n"+key_expr_define) -def unique_by_key(array, key_expr, **kwargs): +def unique_by_key(ary, key_expr="a", extra_args=[], queue=None, preamble=""): """ :arg extra_args: a list of tuples *(name, value)* specifying extra arguments to pass to the scan procedure. """ + if len(ary) > np.iinfo(np.uint32): + scan_dtype = np.uint64 + else: + scan_dtype = np.uint32 + + extra_args_types = tuple((name, val.dtype) for name, val in extra_args) + extra_args_values = tuple(val for name, val in extra_args) + + knl = _get_unique_by_key_kernel(ary.context, ary.dtype, key_expr, scan_dtype, + extra_args_types, preamble) + out = cl_array.empty_like(ary) + count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64) + knl(ary, out, count, *extra_args_values, queue=queue) + return out, count # }}} diff --git a/test/test_array.py b/test/test_array.py index 32221b69c5c53d8a9d768614e15c66eb4b549de3..a1cc402507974ebe13ba715541087da2fee36add 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -801,6 +801,27 @@ def test_partition(ctx_factory): assert (true_dev.get()[:count_true_dev] == true_host).all() assert (false_dev.get()[:n-count_true_dev] == false_host).all() +@pytools.test.mark_test.opencl +def test_unique(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + from pyopencl.clrandom import rand as clrand + for n in scan_test_counts: + a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000) + a = a_dev.get() + a = np.sort(a) + a_dev = cl_array.to_device(queue, a) + + a_unique_host = np.unique(a) + + from pyopencl.scan import unique_by_key + a_unique_dev, count_unique_dev = unique_by_key(a_dev) + + count_unique_dev = count_unique_dev.get() + + assert (a_unique_dev.get()[:count_unique_dev] == a_unique_host).all() + @pytools.test.mark_test.opencl def test_stride_preservation(ctx_factory): context = ctx_factory()