From e1d703956db154bcbaa339f6bcd019a243239949 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 31 Jul 2012 00:21:41 -0400
Subject: [PATCH] Replace MSD radix sort with LSD radix sort, faster and
 simpler.

---
 pyopencl/algorithm.py | 173 ++++++++++--------------------------------
 test/test_array.py    |  16 +++-
 2 files changed, 53 insertions(+), 136 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 0aa4a02f..d7ddc731 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -251,11 +251,6 @@ def _make_sort_scan_type(device, bits, index_dtype):
     for mnr in range(2**bits):
         fields.append(('c%s' % _padded_bin(mnr, bits), index_dtype))
 
-    fields.append(("segment_start_index", index_dtype))
-    fields.append(("bin_nr", np.uint8))
-
-    assert bits <= 8
-
     dtype = np.dtype(fields)
 
     name = "pyopencl_sort_scan_%s_%dbits_t" % (
@@ -298,6 +293,9 @@ RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL//
     {
         return ${get_count_branch("")};
     }
+
+    #define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1})
+
 """, strict_undefined=True)
 
 # }}}
@@ -311,106 +309,60 @@ RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL//
         %for mnr in range(2**bits):
             result.c${padded_bin(mnr, bits)} = 0;
         %endfor
-        result.segment_start_index = ${index_type_max};
         return result;
     }
 
-    // considers bits start_bit,start_bit-1,...,start_bit-bits+1
+    // considers bits (base_bit+bits-1, ..., base_bit)
     scan_t scan_t_from_value(
         key_t key,
-        int start_bit,
+        int base_bit,
         int i
     )
     {
         // extract relevant bit range
-        key_t bin_nr = (key >> (start_bit-${bits}+1)) & ${2**bits - 1};
+        key_t bin_nr = BIN_NR(key);
 
         dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr));
 
         scan_t result;
         %for mnr in range(2**bits):
-            <% field = "c"+padded_bin(mnr, bits) %>
-            result.${field} = (bin_nr == ${mnr});
+            result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr});
         %endfor
-        result.bin_nr = bin_nr;
-        result.segment_start_index = i;
 
         return result;
     }
 
     scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary)
     {
-        if (!across_seg_boundary)
-        {
-            %for mnr in range(2**bits):
-                <% field = "c"+padded_bin(mnr, bits) %>
-                b.${field} = a.${field} + b.${field};
-            %endfor
-            b.segment_start_index =
-                min(a.segment_start_index, b.segment_start_index);
-        }
-
-        // directly use b.bin_nr
+        %for mnr in range(2**bits):
+            <% field = "c"+padded_bin(mnr, bits) %>
+            b.${field} = a.${field} + b.${field};
+        %endfor
+
         return b;
     }
 """, strict_undefined=True)
 
 RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
     {
-        /* Am I the last particle in my current box? */
-        /* NOTE: assumes past-end seg flag is set to one! */
-        if (seg_start_flags[i+1])
-        {
-            %for mnr in range(2**bits):
-                segment_bin_counts[
-                    item.segment_start_index + ${mnr}*padded_n]
-                        = item.c${padded_bin(mnr, bits)};
-            %endfor
-        }
-
-        bin_counts[i] = get_count(item, item.bin_nr);
-        dbg_printf(("bins %d (%d): %d %d %d %d\n",
-            i, item.bin_nr, item.c00, item.c01, item.c10, item.c11));
-
-        bin_number[i] = item.bin_nr;
-        segment_starts[i] = item.segment_start_index;
-    }
-""", strict_undefined=True)
-
-# }}}
-
-# {{{ reorder kernel
-
-RADIX_POSTPROC_KERNEL_TPL =  Template(r"""//CL//
-    index_t segment_start = segment_starts[i];
-
-    unsigned char my_bin_nr = bin_number[i];
-
-    index_t previous_segments_size = 0;
-    %for mnr in range(2**bits):
-        previous_segments_size +=
-            (my_bin_nr > ${mnr})
-                ? segment_bin_counts[segment_start + ${mnr}*padded_n]
-                : 0;
-    %endfor
-
-    index_t my_bin_index = bin_counts[i]-1;
-    index_t tgt_idx =
-        segment_start
-        + previous_segments_size
-        + my_bin_index;
+        key_t key = ${key_expr};
+        key_t my_bin_nr = BIN_NR(key);
 
-    dbg_printf(("moving %d -> %d\n", i, tgt_idx));
-    dbg_printf(("my_bin_index %d\n", my_bin_index));
+        index_t previous_bins_size = 0;
+        %for mnr in range(2**bits):
+            previous_bins_size +=
+                (my_bin_nr > ${mnr})
+                    ? last_item.c${padded_bin(mnr, bits)}
+                    : 0;
+        %endfor
 
-    %for arg_name in sort_arg_names:
-        sorted_${arg_name}[tgt_idx] = ${arg_name}[i];
-    %endfor
+        index_t tgt_idx =
+            previous_bins_size
+            + get_count(item, my_bin_nr) - 1;
 
-    /* Am I at the start of a new segment? */
-    if (my_bin_index == 0)
-    {
-        seg_start_flags[tgt_idx] = 1;
+        %for arg_name in sort_arg_names:
+            sorted_${arg_name}[tgt_idx] = ${arg_name}[i];
+        %endfor
     }
 """, strict_undefined=True)
 
@@ -459,20 +411,16 @@ class RadixSort(object):
                 _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype)
 
         from pyopencl.tools import VectorArg, ScalarArg
-        scan_arguments = list(self.arguments) + [
-                VectorArg(np.uint8, "seg_start_flags"),
-                VectorArg(self.index_dtype, "segment_starts"), # [n]
-                VectorArg(self.index_dtype, "bin_counts"), # [n]
-                VectorArg(self.index_dtype, "segment_bin_counts"), # [nbins * pad(n)]
-                VectorArg(np.uint8, "bin_number"), # [n]
-
-                ScalarArg(np.int32, "start_bit"),
-                ScalarArg(self.index_dtype, "padded_n"),
-                ]
+        scan_arguments = (
+                list(self.arguments)
+                + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments
+                    if arg.name in sort_arg_names]
+                + [ ScalarArg(np.int32, "base_bit") ])
 
         codegen_args = dict(
                 bits=self.bits,
                 key_ctype=dtype_to_ctype(self.key_dtype),
+                key_expr=key_expr,
                 index_ctype=dtype_to_ctype(self.index_dtype),
                 index_type_max=np.iinfo(self.index_dtype).max,
                 padded_bin=_padded_bin,
@@ -487,24 +435,12 @@ class RadixSort(object):
         self.scan_kernel = GenericScanKernel(
                 context, scan_dtype,
                 arguments=scan_arguments,
-                input_expr="scan_t_from_value(%s, start_bit, i)" % key_expr,
+                input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,
                 scan_expr="scan_t_add(a, b, across_seg_boundary)",
                 neutral="scan_t_neutral()",
-                is_segment_start_expr="seg_start_flags[i]",
                 output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args),
                 preamble=scan_preamble, options=self.options)
 
-        postproc_kernel_source = RADIX_POSTPROC_KERNEL_TPL.render(**codegen_args)
-
-        from pyopencl.elementwise import ElementwiseKernel
-        self.postproc_kernel = ElementwiseKernel(
-                context,
-                scan_arguments
-                + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments
-                    if arg.name in sort_arg_names],
-                str(postproc_kernel_source), name="postproc",
-                preamble=str(preamble), options=self.options)
-
         for i, arg in enumerate(self.arguments):
             if isinstance(arg, VectorArg):
                 self.first_array_arg_idx = i
@@ -541,56 +477,27 @@ class RadixSort(object):
         if queue is None:
             queue = args[self.first_array_arg_idx].queue
 
-        from pytools import div_ceil
-        padded_n = div_ceil(n, 256)*256
-
-        nbins = 2**self.bits
-        seg_start_flags = cl.array.zeros(queue, n+1, dtype=np.uint8,
-                allocator=allocator)
-
-        # set last seg_start_flag to 1
-        cl.enqueue_copy(queue, seg_start_flags.data,
-                seg_start_flags.dtype.type(1),
-                device_offset=seg_start_flags.dtype.itemsize*n)
-
-        segment_starts = cl.array.zeros(queue, n, dtype=self.index_dtype,
-                allocator=allocator)
-        bin_counts = cl.array.empty(queue, n, dtype=self.index_dtype,
-                allocator=allocator)
-        segment_bin_counts = cl.array.empty(queue, padded_n*nbins, dtype=self.index_dtype,
-                allocator=allocator)
-        bin_number = cl.array.empty(queue, n, dtype=np.uint8,
-                allocator=allocator)
-
         args = list(args)
 
         kwargs = dict(queue=queue)
 
-        start_bit = int(key_bits) - 1
-        while start_bit >= 0:
-            scan_args = args + [
-                    seg_start_flags,
-                    segment_starts,
-                    bin_counts,
-                    segment_bin_counts,
-                    bin_number,
-                    start_bit, padded_n]
-
-            self.scan_kernel(*scan_args, **kwargs)
-
+        base_bit = 0
+        while base_bit < key_bits:
             sorted_args = [
                     cl.array.empty(queue, n, arg_descr.dtype, allocator=allocator)
                     for arg_descr in self.arguments
                     if arg_descr.name in self.sort_arg_names]
 
-            self.postproc_kernel(*(scan_args+sorted_args), **kwargs)
+            scan_args = args + sorted_args + [base_bit]
+
+            self.scan_kernel(*scan_args, **kwargs)
 
             # substitute sorted
             for i, arg_descr in enumerate(self.arguments):
                 if arg_descr.name in self.sort_arg_names:
                     args[i] = sorted_args[self.sort_arg_names.index(arg_descr.name)]
 
-            start_bit -= self.bits
+            base_bit += self.bits
 
         return [arg_val
                 for arg_descr, arg_val in zip(self.arguments, args)
diff --git a/test/test_array.py b/test/test_array.py
index 0fde15a4..9e680727 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -1050,18 +1050,28 @@ def test_sort(ctx_factory):
     from pyopencl.clrandom import RanluxGenerator
     rng = RanluxGenerator(queue, seed=15)
 
+    from time import time
+
     for n in scan_test_counts:
         print n
 
-        print "rng"
+        print "  rng"
         a_dev = rng.uniform(queue, (n,), dtype=dtype, a=0, b=2**16)
         a = a_dev.get()
 
-        print "device"
+        dev_start = time()
+        print "  device"
         a_dev_sorted, = sort(a_dev, key_bits=16)
         queue.finish()
-        print "numpy"
+        dev_end = time()
+        print "  numpy"
         a_sorted = np.sort(a)
+        numpy_end = time()
+
+        numpy_elapsed = numpy_end-dev_end
+        dev_elapsed = dev_end-dev_start
+        print  "  dev: %.2f s numpy: %.2f ratio: %.1fx" % (
+                dev_elapsed, numpy_elapsed, dev_elapsed/numpy_elapsed)
         assert (a_dev_sorted.get() == a_sorted).all()
 
 
-- 
GitLab