diff --git a/pyopencl/bitonic_sort.py b/pyopencl/bitonic_sort.py
index a1d9b7e2a63c80e10bd8709b0e9e96aac1a3a3d5..a6e9b5a1d952b46a9d5df6c878a273c150099ab5 100644
--- a/pyopencl/bitonic_sort.py
+++ b/pyopencl/bitonic_sort.py
@@ -57,6 +57,10 @@ class BitonicSort(object):
     that is a power of 2.
 
     .. versionadded:: 2015.2
+
+    .. seealso:: :class:`pyopencl.algorithm.RadixSort`
+
+    .. autofunction:: __call__
     """
 
     kernels_srcs = {
@@ -70,16 +74,21 @@ class BitonicSort(object):
             'PML': _tmpl.ParallelMerge_Local
             }
 
-    def __init__(self, context, key_dtype, idx_dtype=None):
-        self.dtype = dtype_to_ctype(key_dtype)
+    def __init__(self, context):
         self.context = context
-        if idx_dtype is None:
-            self.idx_t = 'uint'  # Dummy
 
-        else:
-            self.idx_t = dtype_to_ctype(idx_dtype)
+    def __call__(self, arr, idx=None, queue=None, wait_for=None, axis=0):
+        """
+        :arg arr: the array to be sorted. Will be overwritten with the sorted array.
+        :arg idx: an array of indices to be tracked along with the sorting of *arr*
+        :arg queue: a :class:`pyopencl.CommandQueue`, defaults to the array's queue
+            if None
+        :arg wait_for: a list of :class:`pyopencl.Event` instances or None
+        :arg axis: the axis of the array by which to sort
+
+        :returns: a tuple (sorted_array, event)
+        """
 
-    def __call__(self, arr, idx=None, mkcpy=True, queue=None, wait_for=None, axis=0):
         if queue is None:
             queue = arr.queue
 
@@ -95,14 +104,17 @@ class BitonicSort(object):
         if not _is_power_of_2(arr.shape[axis]):
             raise ValueError("sorted array axis length must be a power of 2")
 
-        arr = arr.copy() if mkcpy else arr
-
         if idx is None:
             argsort = 0
         else:
             argsort = 1
 
-        run_queue = self.sort_b_prepare_wl(argsort, arr.shape, axis)
+        run_queue = self.sort_b_prepare_wl(
+                argsort,
+                arr.dtype,
+                idx.dtype if idx is not None else None, arr.shape,
+                axis)
+
         knl, nt, wg, aux = run_queue[0]
 
         if idx is not None:
@@ -143,7 +155,15 @@ class BitonicSort(object):
         return prg
 
     @memoize_method
-    def sort_b_prepare_wl(self, argsort, shape, axis):
+    def sort_b_prepare_wl(self, argsort, key_dtype, idx_dtype, shape, axis):
+        key_ctype = dtype_to_ctype(key_dtype)
+
+        if idx_dtype is None:
+            idx_ctype = 'uint'  # Dummy
+
+        else:
+            idx_ctype = dtype_to_ctype(idx_dtype)
+
         run_queue = []
         ds = int(shape[axis])
         size = reduce(mul, shape)
@@ -159,7 +179,7 @@ class BitonicSort(object):
         wg = min(ds, self.context.devices[0].max_work_group_size)
         length = wg >> 1
         prg = self.get_program(
-                'BLO', argsort, (1, 1, self.dtype, self.idx_t, ds, ns))
+                'BLO', argsort, (1, 1, key_ctype, idx_ctype, ds, ns))
         run_queue.append((prg.run, size, (wg,), True))
 
         while length < ds:
@@ -183,7 +203,7 @@ class BitonicSort(object):
                 nthreads = size >> ninc
 
                 prg = self.get_program(letter, argsort,
-                        (inc, direction, self.dtype, self.idx_t,  ds, ns))
+                        (inc, direction, key_ctype, idx_ctype,  ds, ns))
                 run_queue.append((prg.run, nthreads, None, False,))
                 inc >>= ninc
 
diff --git a/test/test_algorithm.py b/test/test_algorithm.py
index cc32b1cb0e0444a9cc451f9e983cdfa5019258b9..fca3f6f6a921ec1b1e458c5e11d014ff8cb30bc8 100644
--- a/test/test_algorithm.py
+++ b/test/test_algorithm.py
@@ -857,8 +857,8 @@ def test_bitonic_sort(ctx_factory, size, dtype):
     from pyopencl.bitonic_sort import BitonicSort
 
     s = clrandom.rand(queue, (2, size, 3,), dtype, luxury=None, a=0, b=1.0)
-    sorter = BitonicSort(ctx, s.dtype)
-    sgs, evt = sorter(s, axis=1)
+    sorter = BitonicSort(ctx)
+    sgs, evt = sorter(s.copy(), axis=1)
     assert np.array_equal(np.sort(s.get(), axis=1), sgs.get())
 
 
@@ -884,9 +884,9 @@ def test_bitonic_argsort(ctx_factory, size, dtype):
     index = cl_array.arange(queue, 0, size, 1, dtype=np.int32)
     m = clrandom.rand(queue, (size,), np.float32, luxury=None, a=0, b=1.0)
 
-    sorterm = BitonicSort(ctx, m.dtype, idx_dtype=index.dtype)
+    sorterm = BitonicSort(ctx)
 
-    ms, evt = sorterm(m, idx=index, axis=0)
+    ms, evt = sorterm(m.copy(), idx=index, axis=0)
 
     assert np.array_equal(np.sort(m.get()), ms.get())