diff --git a/pyopencl/bitonic_sort.py b/pyopencl/bitonic_sort.py index a1d9b7e2a63c80e10bd8709b0e9e96aac1a3a3d5..a6e9b5a1d952b46a9d5df6c878a273c150099ab5 100644 --- a/pyopencl/bitonic_sort.py +++ b/pyopencl/bitonic_sort.py @@ -57,6 +57,10 @@ class BitonicSort(object): that is a power of 2. .. versionadded:: 2015.2 + + .. seealso:: :class:`pyopencl.algorithm.RadixSort` + + .. autofunction:: __call__ """ kernels_srcs = { @@ -70,16 +74,21 @@ class BitonicSort(object): 'PML': _tmpl.ParallelMerge_Local } - def __init__(self, context, key_dtype, idx_dtype=None): - self.dtype = dtype_to_ctype(key_dtype) + def __init__(self, context): self.context = context - if idx_dtype is None: - self.idx_t = 'uint' # Dummy - else: - self.idx_t = dtype_to_ctype(idx_dtype) + def __call__(self, arr, idx=None, queue=None, wait_for=None, axis=0): + """ + :arg arr: the array to be sorted. Will be overwritten with the sorted array. + :arg idx: an array of indices to be tracked along with the sorting of *arr* + :arg queue: a :class:`pyopencl.CommandQueue`, defaults to the array's queue + if None + :arg wait_for: a list of :class:`pyopencl.Event` instances or None + :arg axis: the axis of the array by which to sort + + :returns: a tuple (sorted_array, event) + """ - def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None, axis=0): if queue is None: queue = arr.queue @@ -95,14 +104,17 @@ class BitonicSort(object): if not _is_power_of_2(arr.shape[axis]): raise ValueError("sorted array axis length must be a power of 2") - arr = arr.copy() if mkcpy else arr - if idx is None: argsort = 0 else: argsort = 1 - run_queue = self.sort_b_prepare_wl(argsort, arr.shape, axis) + run_queue = self.sort_b_prepare_wl( + argsort, + arr.dtype, + idx.dtype if idx is not None else None, arr.shape, + axis) + knl, nt, wg, aux = run_queue[0] if idx is not None: @@ -143,7 +155,15 @@ class BitonicSort(object): return prg @memoize_method - def sort_b_prepare_wl(self, argsort, shape, axis): + def sort_b_prepare_wl(self, argsort, key_dtype, idx_dtype, shape, axis): + key_ctype = dtype_to_ctype(key_dtype) + + if idx_dtype is None: + idx_ctype = 'uint' # Dummy + + else: + idx_ctype = dtype_to_ctype(idx_dtype) + run_queue = [] ds = int(shape[axis]) size = reduce(mul, shape) @@ -159,7 +179,7 @@ class BitonicSort(object): wg = min(ds, self.context.devices[0].max_work_group_size) length = wg >> 1 prg = self.get_program( - 'BLO', argsort, (1, 1, self.dtype, self.idx_t, ds, ns)) + 'BLO', argsort, (1, 1, key_ctype, idx_ctype, ds, ns)) run_queue.append((prg.run, size, (wg,), True)) while length < ds: @@ -183,7 +203,7 @@ class BitonicSort(object): nthreads = size >> ninc prg = self.get_program(letter, argsort, - (inc, direction, self.dtype, self.idx_t, ds, ns)) + (inc, direction, key_ctype, idx_ctype, ds, ns)) run_queue.append((prg.run, nthreads, None, False,)) inc >>= ninc diff --git a/test/test_algorithm.py b/test/test_algorithm.py index cc32b1cb0e0444a9cc451f9e983cdfa5019258b9..fca3f6f6a921ec1b1e458c5e11d014ff8cb30bc8 100644 --- a/test/test_algorithm.py +++ b/test/test_algorithm.py @@ -857,8 +857,8 @@ def test_bitonic_sort(ctx_factory, size, dtype): from pyopencl.bitonic_sort import BitonicSort s = clrandom.rand(queue, (2, size, 3,), dtype, luxury=None, a=0, b=1.0) - sorter = BitonicSort(ctx, s.dtype) - sgs, evt = sorter(s, axis=1) + sorter = BitonicSort(ctx) + sgs, evt = sorter(s.copy(), axis=1) assert np.array_equal(np.sort(s.get(), axis=1), sgs.get()) @@ -884,9 +884,9 @@ def test_bitonic_argsort(ctx_factory, size, dtype): index = cl_array.arange(queue, 0, size, 1, dtype=np.int32) m = clrandom.rand(queue, (size,), np.float32, luxury=None, a=0, b=1.0) - sorterm = BitonicSort(ctx, m.dtype, idx_dtype=index.dtype) + sorterm = BitonicSort(ctx) - ms, evt = sorterm(m, idx=index, axis=0) + ms, evt = sorterm(m.copy(), idx=index, axis=0) assert np.array_equal(np.sort(m.get()), ms.get())