From e1d703956db154bcbaa339f6bcd019a243239949 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 31 Jul 2012 00:21:41 -0400 Subject: [PATCH] Replace MSD radix sort with LSD radix sort, faster and simpler. --- pyopencl/algorithm.py | 173 ++++++++++-------------------------------- test/test_array.py | 16 +++- 2 files changed, 53 insertions(+), 136 deletions(-) diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 0aa4a02f..d7ddc731 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -251,11 +251,6 @@ def _make_sort_scan_type(device, bits, index_dtype): for mnr in range(2**bits): fields.append(('c%s' % _padded_bin(mnr, bits), index_dtype)) - fields.append(("segment_start_index", index_dtype)) - fields.append(("bin_nr", np.uint8)) - - assert bits <= 8 - dtype = np.dtype(fields) name = "pyopencl_sort_scan_%s_%dbits_t" % ( @@ -298,6 +293,9 @@ RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL// { return ${get_count_branch("")}; } + + #define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1}) + """, strict_undefined=True) # }}} @@ -311,106 +309,60 @@ RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL// %for mnr in range(2**bits): result.c${padded_bin(mnr, bits)} = 0; %endfor - result.segment_start_index = ${index_type_max}; return result; } - // considers bits start_bit,start_bit-1,...,start_bit-bits+1 + // considers bits (base_bit+bits-1, ..., base_bit) scan_t scan_t_from_value( key_t key, - int start_bit, + int base_bit, int i ) { // extract relevant bit range - key_t bin_nr = (key >> (start_bit-${bits}+1)) & ${2**bits - 1}; + key_t bin_nr = BIN_NR(key); dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr)); scan_t result; %for mnr in range(2**bits): - <% field = "c"+padded_bin(mnr, bits) %> - result.${field} = (bin_nr == ${mnr}); + result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr}); %endfor - result.bin_nr = bin_nr; - result.segment_start_index = i; return result; } scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary) { - if (!across_seg_boundary) - { - %for mnr in range(2**bits): - <% field = "c"+padded_bin(mnr, bits) %> - b.${field} = a.${field} + b.${field}; - %endfor - b.segment_start_index = - min(a.segment_start_index, b.segment_start_index); - } - - // directly use b.bin_nr + %for mnr in range(2**bits): + <% field = "c"+padded_bin(mnr, bits) %> + b.${field} = a.${field} + b.${field}; + %endfor + return b; } """, strict_undefined=True) RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL// { - /* Am I the last particle in my current box? */ - /* NOTE: assumes past-end seg flag is set to one! */ - if (seg_start_flags[i+1]) - { - %for mnr in range(2**bits): - segment_bin_counts[ - item.segment_start_index + ${mnr}*padded_n] - = item.c${padded_bin(mnr, bits)}; - %endfor - } - - bin_counts[i] = get_count(item, item.bin_nr); - dbg_printf(("bins %d (%d): %d %d %d %d\n", - i, item.bin_nr, item.c00, item.c01, item.c10, item.c11)); - - bin_number[i] = item.bin_nr; - segment_starts[i] = item.segment_start_index; - } -""", strict_undefined=True) - -# }}} - -# {{{ reorder kernel - -RADIX_POSTPROC_KERNEL_TPL = Template(r"""//CL// - index_t segment_start = segment_starts[i]; - - unsigned char my_bin_nr = bin_number[i]; - - index_t previous_segments_size = 0; - %for mnr in range(2**bits): - previous_segments_size += - (my_bin_nr > ${mnr}) - ? segment_bin_counts[segment_start + ${mnr}*padded_n] - : 0; - %endfor - - index_t my_bin_index = bin_counts[i]-1; - index_t tgt_idx = - segment_start - + previous_segments_size - + my_bin_index; + key_t key = ${key_expr}; + key_t my_bin_nr = BIN_NR(key); - dbg_printf(("moving %d -> %d\n", i, tgt_idx)); - dbg_printf(("my_bin_index %d\n", my_bin_index)); + index_t previous_bins_size = 0; + %for mnr in range(2**bits): + previous_bins_size += + (my_bin_nr > ${mnr}) + ? last_item.c${padded_bin(mnr, bits)} + : 0; + %endfor - %for arg_name in sort_arg_names: - sorted_${arg_name}[tgt_idx] = ${arg_name}[i]; - %endfor + index_t tgt_idx = + previous_bins_size + + get_count(item, my_bin_nr) - 1; - /* Am I at the start of a new segment? */ - if (my_bin_index == 0) - { - seg_start_flags[tgt_idx] = 1; + %for arg_name in sort_arg_names: + sorted_${arg_name}[tgt_idx] = ${arg_name}[i]; + %endfor } """, strict_undefined=True) @@ -459,20 +411,16 @@ class RadixSort(object): _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype) from pyopencl.tools import VectorArg, ScalarArg - scan_arguments = list(self.arguments) + [ - VectorArg(np.uint8, "seg_start_flags"), - VectorArg(self.index_dtype, "segment_starts"), # [n] - VectorArg(self.index_dtype, "bin_counts"), # [n] - VectorArg(self.index_dtype, "segment_bin_counts"), # [nbins * pad(n)] - VectorArg(np.uint8, "bin_number"), # [n] - - ScalarArg(np.int32, "start_bit"), - ScalarArg(self.index_dtype, "padded_n"), - ] + scan_arguments = ( + list(self.arguments) + + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments + if arg.name in sort_arg_names] + + [ ScalarArg(np.int32, "base_bit") ]) codegen_args = dict( bits=self.bits, key_ctype=dtype_to_ctype(self.key_dtype), + key_expr=key_expr, index_ctype=dtype_to_ctype(self.index_dtype), index_type_max=np.iinfo(self.index_dtype).max, padded_bin=_padded_bin, @@ -487,24 +435,12 @@ class RadixSort(object): self.scan_kernel = GenericScanKernel( context, scan_dtype, arguments=scan_arguments, - input_expr="scan_t_from_value(%s, start_bit, i)" % key_expr, + input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr, scan_expr="scan_t_add(a, b, across_seg_boundary)", neutral="scan_t_neutral()", - is_segment_start_expr="seg_start_flags[i]", output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args), preamble=scan_preamble, options=self.options) - postproc_kernel_source = RADIX_POSTPROC_KERNEL_TPL.render(**codegen_args) - - from pyopencl.elementwise import ElementwiseKernel - self.postproc_kernel = ElementwiseKernel( - context, - scan_arguments - + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments - if arg.name in sort_arg_names], - str(postproc_kernel_source), name="postproc", - preamble=str(preamble), options=self.options) - for i, arg in enumerate(self.arguments): if isinstance(arg, VectorArg): self.first_array_arg_idx = i @@ -541,56 +477,27 @@ class RadixSort(object): if queue is None: queue = args[self.first_array_arg_idx].queue - from pytools import div_ceil - padded_n = div_ceil(n, 256)*256 - - nbins = 2**self.bits - seg_start_flags = cl.array.zeros(queue, n+1, dtype=np.uint8, - allocator=allocator) - - # set last seg_start_flag to 1 - cl.enqueue_copy(queue, seg_start_flags.data, - seg_start_flags.dtype.type(1), - device_offset=seg_start_flags.dtype.itemsize*n) - - segment_starts = cl.array.zeros(queue, n, dtype=self.index_dtype, - allocator=allocator) - bin_counts = cl.array.empty(queue, n, dtype=self.index_dtype, - allocator=allocator) - segment_bin_counts = cl.array.empty(queue, padded_n*nbins, dtype=self.index_dtype, - allocator=allocator) - bin_number = cl.array.empty(queue, n, dtype=np.uint8, - allocator=allocator) - args = list(args) kwargs = dict(queue=queue) - start_bit = int(key_bits) - 1 - while start_bit >= 0: - scan_args = args + [ - seg_start_flags, - segment_starts, - bin_counts, - segment_bin_counts, - bin_number, - start_bit, padded_n] - - self.scan_kernel(*scan_args, **kwargs) - + base_bit = 0 + while base_bit < key_bits: sorted_args = [ cl.array.empty(queue, n, arg_descr.dtype, allocator=allocator) for arg_descr in self.arguments if arg_descr.name in self.sort_arg_names] - self.postproc_kernel(*(scan_args+sorted_args), **kwargs) + scan_args = args + sorted_args + [base_bit] + + self.scan_kernel(*scan_args, **kwargs) # substitute sorted for i, arg_descr in enumerate(self.arguments): if arg_descr.name in self.sort_arg_names: args[i] = sorted_args[self.sort_arg_names.index(arg_descr.name)] - start_bit -= self.bits + base_bit += self.bits return [arg_val for arg_descr, arg_val in zip(self.arguments, args) diff --git a/test/test_array.py b/test/test_array.py index 0fde15a4..9e680727 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -1050,18 +1050,28 @@ def test_sort(ctx_factory): from pyopencl.clrandom import RanluxGenerator rng = RanluxGenerator(queue, seed=15) + from time import time + for n in scan_test_counts: print n - print "rng" + print " rng" a_dev = rng.uniform(queue, (n,), dtype=dtype, a=0, b=2**16) a = a_dev.get() - print "device" + dev_start = time() + print " device" a_dev_sorted, = sort(a_dev, key_bits=16) queue.finish() - print "numpy" + dev_end = time() + print " numpy" a_sorted = np.sort(a) + numpy_end = time() + + numpy_elapsed = numpy_end-dev_end + dev_elapsed = dev_end-dev_start + print " dev: %.2f s numpy: %.2f ratio: %.1fx" % ( + dev_elapsed, numpy_elapsed, dev_elapsed/numpy_elapsed) assert (a_dev_sorted.get() == a_sorted).all() -- GitLab