From f73aaf39e3993dd1d400e1f2539ffa5d72bd3d55 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 9 May 2022 10:53:29 -0500
Subject: [PATCH] remove pytools.Record in cl.algorithm

---
 pyopencl/algorithm.py | 37 ++++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index c3cbf251..0294c2e2 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -29,13 +29,19 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 OTHER DEALINGS IN THE SOFTWARE.
 """
 
+from dataclasses import dataclass
+from typing import Optional
 
 import numpy as np
+
 import pyopencl as cl
-import pyopencl.array  # noqa
-from pyopencl.scan import ScanTemplate
+import pyopencl.array
+
+from pyopencl.elementwise import ElementwiseKernel
+from pyopencl.scan import ScanTemplate, GenericScanKernel
 from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code
-from pytools import memoize, memoize_method, Record
+
+from pytools import memoize, memoize_method
 from mako.template import Template
 
 
@@ -419,10 +425,6 @@ RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
 
 # {{{ driver
 
-# import hoisted here to be used as a default argument in the constructor
-from pyopencl.scan import GenericScanKernel
-
-
 class RadixSort:
     """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_
     on the compute device.
@@ -722,8 +724,14 @@ def _get_arg_list(arg_list, prefix=""):
     return result
 
 
-class BuiltList(Record):
-    pass
+@dataclass
+class BuiltList:
+    count: Optional[int]
+    starts: Optional[pyopencl.array.Array]
+    lists: Optional[pyopencl.array.Array]
+    num_nonempty_lists: Optional[int] = None
+    nonempty_indices: Optional[pyopencl.array.Array] = None
+    compressed_indices: Optional[pyopencl.array.Array] = None
 
 
 class ListOfListsBuilder:
@@ -877,7 +885,6 @@ class ListOfListsBuilder:
 
     @memoize_method
     def get_scan_kernel(self, index_dtype):
-        from pyopencl.scan import GenericScanKernel
         return GenericScanKernel(
                 self.context, index_dtype,
                 arguments="__global %s *ary" % dtype_to_ctype(index_dtype),
@@ -897,7 +904,6 @@ class ListOfListsBuilder:
         """
         arguments = Template(arguments)
 
-        from pyopencl.scan import GenericScanKernel
         return GenericScanKernel(
                 self.context, index_dtype,
                 arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)),
@@ -1313,8 +1319,11 @@ class ListOfListsBuilder:
 
 # {{{ key-value sorting
 
-class _KernelInfo(Record):
-    pass
+@dataclass(frozen=True)
+class _KernelInfo:
+    by_target_sorter: RadixSort
+    start_finder: ElementwiseKernel
+    bound_propagation_scan: GenericScanKernel
 
 
 def _make_cl_int_literal(value, dtype):
@@ -1356,7 +1365,6 @@ class KeyValueSorter:
 
     @memoize_method
     def get_kernels(self, key_dtype, value_dtype, starts_dtype):
-        from pyopencl.algorithm import RadixSort
         from pyopencl.tools import VectorArg, ScalarArg
 
         by_target_sorter = RadixSort(
@@ -1387,7 +1395,6 @@ class KeyValueSorter:
                             ),
                         var_values=())
 
-        from pyopencl.scan import GenericScanKernel
         bound_propagation_scan = GenericScanKernel(
                 self.context, starts_dtype,
                 arguments=[
-- 
GitLab