diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 387e31a4d0a455368ca45bef287220b5f9ead72e..63190628fc28f09d268648fc25a88bd06d887f1b 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -392,6 +392,9 @@ RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
 
 
 # {{{ driver
+# import hoisted here to be used as a default argument in the constructor
+from pyopencl.scan import GenericScanKernel
+
 
 class RadixSort(object):
     """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_
@@ -401,7 +404,7 @@ class RadixSort(object):
     """
     def __init__(self, context, arguments, key_expr, sort_arg_names,
             bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32,
-            options=[]):
+            scan_kernel=GenericScanKernel, options=[]):
         """
         :arg arguments: A string of comma-separated C argument declarations.
             If *arguments* is specified, then *input_expr* must also be
@@ -469,8 +472,7 @@ class RadixSort(object):
         scan_preamble = preamble \
                 + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args)
 
-        from pyopencl.scan import GenericScanKernel
-        self.scan_kernel = GenericScanKernel(
+        self.scan_kernel = scan_kernel(
                 context, scan_dtype,
                 arguments=scan_arguments,
                 input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,