From 04bf811eee0145c473f6acd389e1107128d236a0 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 31 Jul 2012 00:57:13 -0500
Subject: [PATCH] Sorting tweaks.

---
 pyopencl/algorithm.py | 25 +++++++++++++------------
 test/test_array.py    |  8 +++++---
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index d7ddc731..fffbef3e 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -276,17 +276,6 @@ RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL//
     #endif
 
     <%
-      def get_count_branch(known_bits):
-          if len(known_bits) == bits:
-              return "s.c%s" % known_bits
-
-          b = len(known_bits)
-          boundary_mnr = known_bits + "1" + (bits-b-1)*"0"
-
-          return ("((mnr < %s) ? %s : %s)" % (
-              int(boundary_mnr, 2),
-              get_count_branch(known_bits+"0"),
-              get_count_branch(known_bits+"1")))
     %>
 
     index_t get_count(scan_t s, int mnr)
@@ -375,7 +364,7 @@ class RadixSort(object):
     <https://en.wikipedia.org/wiki/Radix_sort>`_ on the compute device.
     """
     def __init__(self, context, arguments, key_expr, sort_arg_names,
-            bits_at_a_time=4, index_dtype=np.int32, key_dtype=np.uint32,
+            bits_at_a_time=3, index_dtype=np.int32, key_dtype=np.uint32,
             options=[]):
         """
         :arg arguments: A string of comma-separated C argument declarations.
@@ -417,6 +406,17 @@ class RadixSort(object):
                     if arg.name in sort_arg_names]
                 + [ ScalarArg(np.int32, "base_bit") ])
 
+        def get_count_branch(known_bits):
+            if len(known_bits) == self.bits:
+                return "s.c%s" % known_bits
+
+            boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0"
+
+            return ("((mnr < %s) ? %s : %s)" % (
+                int(boundary_mnr, 2),
+                get_count_branch(known_bits+"0"),
+                get_count_branch(known_bits+"1")))
+
         codegen_args = dict(
                 bits=self.bits,
                 key_ctype=dtype_to_ctype(self.key_dtype),
@@ -426,6 +426,7 @@ class RadixSort(object):
                 padded_bin=_padded_bin,
                 scan_ctype=scan_ctype,
                 sort_arg_names=sort_arg_names,
+                get_count_branch=get_count_branch,
                 )
 
         preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args)
diff --git a/test/test_array.py b/test/test_array.py
index 9e680727..cdc1cfef 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -816,7 +816,9 @@ scan_test_counts = [
     2 ** 20 - 2 ** 18,
     2 ** 20 - 2 ** 18 + 5,
     2 ** 20 + 1,
-    2 ** 20, 2 ** 24
+    2 ** 20,
+    2 ** 23 + 3,
+    2 ** 24 + 5
     ]
 
 @pytools.test.mark_test.opencl
@@ -1070,8 +1072,8 @@ def test_sort(ctx_factory):
 
         numpy_elapsed = numpy_end-dev_end
         dev_elapsed = dev_end-dev_start
-        print  "  dev: %.2f s numpy: %.2f ratio: %.1fx" % (
-                dev_elapsed, numpy_elapsed, dev_elapsed/numpy_elapsed)
+        print  "  dev: %.2f MKeys/s numpy: %.2f MKeys/s ratio: %.2fx" % (
+                1e-6*n/dev_elapsed, 1e-6*n/numpy_elapsed, numpy_elapsed/dev_elapsed)
         assert (a_dev_sorted.get() == a_sorted).all()
 
 
-- 
GitLab