From f73aaf39e3993dd1d400e1f2539ffa5d72bd3d55 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Mon, 9 May 2022 10:53:29 -0500 Subject: [PATCH] remove pytools.Record in cl.algorithm --- pyopencl/algorithm.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index c3cbf251..0294c2e2 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -29,13 +29,19 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from dataclasses import dataclass +from typing import Optional import numpy as np + import pyopencl as cl -import pyopencl.array # noqa -from pyopencl.scan import ScanTemplate +import pyopencl.array + +from pyopencl.elementwise import ElementwiseKernel +from pyopencl.scan import ScanTemplate, GenericScanKernel from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code -from pytools import memoize, memoize_method, Record + +from pytools import memoize, memoize_method from mako.template import Template @@ -419,10 +425,6 @@ RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL// # {{{ driver -# import hoisted here to be used as a default argument in the constructor -from pyopencl.scan import GenericScanKernel - - class RadixSort: """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_ on the compute device. @@ -722,8 +724,14 @@ def _get_arg_list(arg_list, prefix=""): return result -class BuiltList(Record): - pass +@dataclass +class BuiltList: + count: Optional[int] + starts: Optional[pyopencl.array.Array] + lists: Optional[pyopencl.array.Array] + num_nonempty_lists: Optional[int] = None + nonempty_indices: Optional[pyopencl.array.Array] = None + compressed_indices: Optional[pyopencl.array.Array] = None class ListOfListsBuilder: @@ -877,7 +885,6 @@ class ListOfListsBuilder: @memoize_method def get_scan_kernel(self, index_dtype): - from pyopencl.scan import GenericScanKernel return GenericScanKernel( self.context, index_dtype, arguments="__global %s *ary" % dtype_to_ctype(index_dtype), @@ -897,7 +904,6 @@ class ListOfListsBuilder: """ arguments = Template(arguments) - from pyopencl.scan import GenericScanKernel return GenericScanKernel( self.context, index_dtype, arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)), @@ -1313,8 +1319,11 @@ class ListOfListsBuilder: # {{{ key-value sorting -class _KernelInfo(Record): - pass +@dataclass(frozen=True) +class _KernelInfo: + by_target_sorter: RadixSort + start_finder: ElementwiseKernel + bound_propagation_scan: GenericScanKernel def _make_cl_int_literal(value, dtype): @@ -1356,7 +1365,6 @@ class KeyValueSorter: @memoize_method def get_kernels(self, key_dtype, value_dtype, starts_dtype): - from pyopencl.algorithm import RadixSort from pyopencl.tools import VectorArg, ScalarArg by_target_sorter = RadixSort( @@ -1387,7 +1395,6 @@ class KeyValueSorter: ), var_values=()) - from pyopencl.scan import GenericScanKernel bound_propagation_scan = GenericScanKernel( self.context, starts_dtype, arguments=[ -- GitLab