diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 387e31a4d0a455368ca45bef287220b5f9ead72e..63190628fc28f09d268648fc25a88bd06d887f1b 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -392,6 +392,9 @@ RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL// # {{{ driver +# import hoisted here to be used as a default argument in the constructor +from pyopencl.scan import GenericScanKernel + class RadixSort(object): """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_ @@ -401,7 +404,7 @@ class RadixSort(object): """ def __init__(self, context, arguments, key_expr, sort_arg_names, bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32, - options=[]): + scan_kernel=GenericScanKernel, options=[]): """ :arg arguments: A string of comma-separated C argument declarations. If *arguments* is specified, then *input_expr* must also be @@ -469,8 +472,7 @@ class RadixSort(object): scan_preamble = preamble \ + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args) - from pyopencl.scan import GenericScanKernel - self.scan_kernel = GenericScanKernel( + self.scan_kernel = scan_kernel( context, scan_dtype, arguments=scan_arguments, input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,