From b78495f14e0e3219aad5689f6ed9ea23fdb2715f Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 9 May 2022 10:32:18 -0500
Subject: [PATCH] remove pytools.Record in cl.reduction

---
 pyopencl/reduction.py | 56 ++++++++++++++++++++++++++-----------------
 1 file changed, 34 insertions(+), 22 deletions(-)

diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index e8f69342..f8853feb 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -29,13 +29,16 @@ Based on code/ideas by Mark Harris <mharris@nvidia.com>.
 None of the original source code remains.
 """
 
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
 
 import pyopencl as cl
 from pyopencl.tools import (
-        context_dependent_memoize,
-        dtype_to_ctype, KernelTemplateBase,
+        Argument, KernelTemplateBase,
+        context_dependent_memoize, dtype_to_ctype,
         _process_code_for_macro)
-import numpy as np
 
 
 # {{{ kernel source
@@ -120,6 +123,17 @@ KERNEL = r"""//CL//
 
 # {{{ internal codegen frontends
 
+@dataclass(frozen=True)
+class _ReductionInfo:
+    context: cl.Context
+    source: str
+    group_size: int
+
+    program: cl.Program
+    kernel: cl.Kernel
+    arg_types: List[Argument]
+
+
 def _get_reduction_source(
         ctx, out_type, out_type_size,
         neutral, reduce_expr, map_expr, parsed_args,
@@ -175,15 +189,7 @@ def _get_reduction_source(
         double_support=all(has_double_support(dev) for dev in devices),
         ))
 
-    from pytools import Record
-
-    class ReductionInfo(Record):
-        pass
-
-    return ReductionInfo(
-            context=ctx,
-            source=src,
-            group_size=group_size)
+    return src, group_size
 
 
 def get_reduction_kernel(stage,
@@ -213,24 +219,30 @@ def get_reduction_kernel(stage,
                 [VectorArg(dtype_out, "pyopencl_reduction_inp")]
                 + arguments)
 
-    inf = _get_reduction_source(
+    source, group_size = _get_reduction_source(
             ctx, dtype_to_ctype(dtype_out), dtype_out.itemsize,
             neutral, reduce_expr, map_expr, arguments,
             name, preamble, arg_prep, device, max_group_size)
 
-    inf.program = cl.Program(ctx, inf.source)
-    inf.program.build(options)
-    inf.kernel = getattr(inf.program, name)
+    program = cl.Program(ctx, source)
+    program.build(options)
 
-    inf.arg_types = arguments
-
-    inf.kernel.set_scalar_arg_dtypes(
+    kernel = getattr(program, name)
+    kernel.set_scalar_arg_dtypes(
             [None, np.int64]
-            + get_arg_list_scalar_arg_dtypes(inf.arg_types)
-            + [np.int64]*3 + [np.uint32, np.int64]
+            + get_arg_list_scalar_arg_dtypes(arguments)
+            + [np.int64]*3
+            + [np.uint32, np.int64]
             )
 
-    return inf
+    return _ReductionInfo(
+        context=ctx,
+        source=source,
+        group_size=group_size,
+        program=program,
+        kernel=kernel,
+        arg_types=arguments
+        )
 
 # }}}
 
-- 
GitLab