diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index d7ddc73196c646da1f796a48af8b310edbf7d64c..fffbef3efe2c79090f856266f1bfdb8aff9c291d 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -276,17 +276,6 @@ RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL// #endif <% - def get_count_branch(known_bits): - if len(known_bits) == bits: - return "s.c%s" % known_bits - - b = len(known_bits) - boundary_mnr = known_bits + "1" + (bits-b-1)*"0" - - return ("((mnr < %s) ? %s : %s)" % ( - int(boundary_mnr, 2), - get_count_branch(known_bits+"0"), - get_count_branch(known_bits+"1"))) %> index_t get_count(scan_t s, int mnr) @@ -375,7 +364,7 @@ class RadixSort(object): <https://en.wikipedia.org/wiki/Radix_sort>`_ on the compute device. """ def __init__(self, context, arguments, key_expr, sort_arg_names, - bits_at_a_time=4, index_dtype=np.int32, key_dtype=np.uint32, + bits_at_a_time=3, index_dtype=np.int32, key_dtype=np.uint32, options=[]): """ :arg arguments: A string of comma-separated C argument declarations. @@ -417,6 +406,17 @@ class RadixSort(object): if arg.name in sort_arg_names] + [ ScalarArg(np.int32, "base_bit") ]) + def get_count_branch(known_bits): + if len(known_bits) == self.bits: + return "s.c%s" % known_bits + + boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0" + + return ("((mnr < %s) ? %s : %s)" % ( + int(boundary_mnr, 2), + get_count_branch(known_bits+"0"), + get_count_branch(known_bits+"1"))) + codegen_args = dict( bits=self.bits, key_ctype=dtype_to_ctype(self.key_dtype), @@ -426,6 +426,7 @@ class RadixSort(object): padded_bin=_padded_bin, scan_ctype=scan_ctype, sort_arg_names=sort_arg_names, + get_count_branch=get_count_branch, ) preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args) diff --git a/test/test_array.py b/test/test_array.py index 9e680727c5be30be7c33831ba7d3a207f58b999c..cdc1cfefd5b3ea3c6fb79fbafab4c839d9dace01 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -816,7 +816,9 @@ scan_test_counts = [ 2 ** 20 - 2 ** 18, 2 ** 20 - 2 ** 18 + 5, 2 ** 20 + 1, - 2 ** 20, 2 ** 24 + 2 ** 20, + 2 ** 23 + 3, + 2 ** 24 + 5 ] @pytools.test.mark_test.opencl @@ -1070,8 +1072,8 @@ def test_sort(ctx_factory): numpy_elapsed = numpy_end-dev_end dev_elapsed = dev_end-dev_start - print " dev: %.2f s numpy: %.2f ratio: %.1fx" % ( - dev_elapsed, numpy_elapsed, dev_elapsed/numpy_elapsed) + print " dev: %.2f MKeys/s numpy: %.2f MKeys/s ratio: %.2fx" % ( + 1e-6*n/dev_elapsed, 1e-6*n/numpy_elapsed, numpy_elapsed/dev_elapsed) assert (a_dev_sorted.get() == a_sorted).all()