diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index c3cbf25194fff7ebbc4b4704c6c8c5877a22dae5..0294c2e2e860a2fda212d1f36175f82719bc1970 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -29,13 +29,19 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from dataclasses import dataclass +from typing import Optional import numpy as np + import pyopencl as cl -import pyopencl.array # noqa -from pyopencl.scan import ScanTemplate +import pyopencl.array + +from pyopencl.elementwise import ElementwiseKernel +from pyopencl.scan import ScanTemplate, GenericScanKernel from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code -from pytools import memoize, memoize_method, Record + +from pytools import memoize, memoize_method from mako.template import Template @@ -419,10 +425,6 @@ RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL// # {{{ driver -# import hoisted here to be used as a default argument in the constructor -from pyopencl.scan import GenericScanKernel - - class RadixSort: """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_ on the compute device. @@ -722,8 +724,14 @@ def _get_arg_list(arg_list, prefix=""): return result -class BuiltList(Record): - pass +@dataclass +class BuiltList: + count: Optional[int] + starts: Optional[pyopencl.array.Array] + lists: Optional[pyopencl.array.Array] + num_nonempty_lists: Optional[int] = None + nonempty_indices: Optional[pyopencl.array.Array] = None + compressed_indices: Optional[pyopencl.array.Array] = None class ListOfListsBuilder: @@ -877,7 +885,6 @@ class ListOfListsBuilder: @memoize_method def get_scan_kernel(self, index_dtype): - from pyopencl.scan import GenericScanKernel return GenericScanKernel( self.context, index_dtype, arguments="__global %s *ary" % dtype_to_ctype(index_dtype), @@ -897,7 +904,6 @@ class ListOfListsBuilder: """ arguments = Template(arguments) - from pyopencl.scan import GenericScanKernel return GenericScanKernel( self.context, index_dtype, arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)), @@ -1313,8 +1319,11 @@ class ListOfListsBuilder: # {{{ key-value sorting -class _KernelInfo(Record): - pass +@dataclass(frozen=True) +class _KernelInfo: + by_target_sorter: RadixSort + start_finder: ElementwiseKernel + bound_propagation_scan: GenericScanKernel def _make_cl_int_literal(value, dtype): @@ -1356,7 +1365,6 @@ class KeyValueSorter: @memoize_method def get_kernels(self, key_dtype, value_dtype, starts_dtype): - from pyopencl.algorithm import RadixSort from pyopencl.tools import VectorArg, ScalarArg by_target_sorter = RadixSort( @@ -1387,7 +1395,6 @@ class KeyValueSorter: ), var_values=()) - from pyopencl.scan import GenericScanKernel bound_propagation_scan = GenericScanKernel( self.context, starts_dtype, arguments=[