From b78495f14e0e3219aad5689f6ed9ea23fdb2715f Mon Sep 17 00:00:00 2001 From: Alexandru Fikl <alexfikl@gmail.com> Date: Mon, 9 May 2022 10:32:18 -0500 Subject: [PATCH] remove pytools.Record in cl.reduction --- pyopencl/reduction.py | 56 ++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index e8f69342..f8853feb 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -29,13 +29,16 @@ Based on code/ideas by Mark Harris <mharris@nvidia.com>. None of the original source code remains. """ +from dataclasses import dataclass +from typing import List + +import numpy as np import pyopencl as cl from pyopencl.tools import ( - context_dependent_memoize, - dtype_to_ctype, KernelTemplateBase, + Argument, KernelTemplateBase, + context_dependent_memoize, dtype_to_ctype, _process_code_for_macro) -import numpy as np # {{{ kernel source @@ -120,6 +123,17 @@ KERNEL = r"""//CL// # {{{ internal codegen frontends +@dataclass(frozen=True) +class _ReductionInfo: + context: cl.Context + source: str + group_size: int + + program: cl.Program + kernel: cl.Kernel + arg_types: List[Argument] + + def _get_reduction_source( ctx, out_type, out_type_size, neutral, reduce_expr, map_expr, parsed_args, @@ -175,15 +189,7 @@ def _get_reduction_source( double_support=all(has_double_support(dev) for dev in devices), )) - from pytools import Record - - class ReductionInfo(Record): - pass - - return ReductionInfo( - context=ctx, - source=src, - group_size=group_size) + return src, group_size def get_reduction_kernel(stage, @@ -213,24 +219,30 @@ def get_reduction_kernel(stage, [VectorArg(dtype_out, "pyopencl_reduction_inp")] + arguments) - inf = _get_reduction_source( + source, group_size = _get_reduction_source( ctx, dtype_to_ctype(dtype_out), dtype_out.itemsize, neutral, reduce_expr, map_expr, arguments, name, preamble, arg_prep, device, max_group_size) - inf.program = cl.Program(ctx, inf.source) - inf.program.build(options) - inf.kernel = getattr(inf.program, name) + program = cl.Program(ctx, source) + program.build(options) - inf.arg_types = arguments - - inf.kernel.set_scalar_arg_dtypes( + kernel = getattr(program, name) + kernel.set_scalar_arg_dtypes( [None, np.int64] - + get_arg_list_scalar_arg_dtypes(inf.arg_types) - + [np.int64]*3 + [np.uint32, np.int64] + + get_arg_list_scalar_arg_dtypes(arguments) + + [np.int64]*3 + + [np.uint32, np.int64] ) - return inf + return _ReductionInfo( + context=ctx, + source=source, + group_size=group_size, + program=program, + kernel=kernel, + arg_types=arguments + ) # }}} -- GitLab