From 2c560eea2a50df924d8bc3678482e929a4f91864 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 11 Dec 2012 23:40:55 -0500 Subject: [PATCH] Centralize arg list parsing. --- pyopencl/algorithm.py | 4 +-- pyopencl/elementwise.py | 17 ++++--------- pyopencl/reduction.py | 49 +++++++++++++++++------------------- pyopencl/scan.py | 37 ++++++--------------------- pyopencl/tools.py | 56 +++++++++++++++++++++++++++++++++++------ 5 files changed, 86 insertions(+), 77 deletions(-) diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py index 4e72a6a5..8aef52ec 100644 --- a/pyopencl/algorithm.py +++ b/pyopencl/algorithm.py @@ -356,8 +356,8 @@ class RadixSort(object): # {{{ arg processing - from pyopencl.scan import _parse_args - self.arguments = _parse_args(arguments) + from pyopencl.tools import parse_arg_list + self.arguments = parse_arg_list(arguments) del arguments self.sort_arg_names = sort_arg_names diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py index 75baa22e..159a594a 100644 --- a/pyopencl/elementwise.py +++ b/pyopencl/elementwise.py @@ -100,11 +100,9 @@ def get_elwise_program(context, arguments, operation, def get_elwise_kernel_and_types(context, arguments, operation, name="elwise_kernel", options=[], preamble="", use_range=False, **kwargs): - if isinstance(arguments, str): - from pyopencl.tools import parse_c_arg - parsed_args = [parse_c_arg(arg) for arg in arguments.split(",")] - else: - parsed_args = arguments + + from pyopencl.tools import parse_arg_list + parsed_args = parse_arg_list(arguments) auto_preamble = kwargs.pop("auto_preamble", True) @@ -143,15 +141,10 @@ def get_elwise_kernel_and_types(context, arguments, operation, name=name, options=options, preamble=preamble, use_range=use_range, **kwargs) - scalar_arg_dtypes = [] - for arg in parsed_args: - if isinstance(arg, ScalarArg): - scalar_arg_dtypes.append(arg.dtype) - else: - scalar_arg_dtypes.append(None) + from pyopencl.tools import get_arg_list_scalar_arg_dtypes kernel = getattr(prg, name) - kernel.set_scalar_arg_dtypes(scalar_arg_dtypes) + kernel.set_scalar_arg_dtypes(get_arg_list_scalar_arg_dtypes(parsed_args)) return kernel, parsed_args diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py index 68562b11..c0f6b6d3 100644 --- a/pyopencl/reduction.py +++ b/pyopencl/reduction.py @@ -140,9 +140,9 @@ KERNEL = """//CL// -def get_reduction_source( +def _get_reduction_source( ctx, out_type, out_type_size, - neutral, reduce_expr, map_expr, arguments, + neutral, reduce_expr, map_expr, parsed_args, name="reduce_kernel", preamble="", device=None, max_group_size=None): @@ -198,7 +198,7 @@ def get_reduction_source( from pyopencl.characterize import has_double_support, has_amd_double_support src = str(Template(KERNEL).render( out_type=out_type, - arguments=arguments, + arguments=", ".join(arg.declarator() for arg in parsed_args), group_size=group_size, no_sync_size=no_sync_size, neutral=neutral, @@ -236,33 +236,30 @@ def get_reduction_kernel(stage, map_expr = "in[i]" if stage == 2: - in_arg = "__global const %s *pyopencl_reduction_inp" % out_type + in_arg = "const %s *pyopencl_reduction_inp" % out_type if arguments: arguments = in_arg + ", " + arguments else: arguments = in_arg - inf = get_reduction_source( + from pyopencl.tools import parse_arg_list, get_arg_list_scalar_arg_dtypes + parsed_args = parse_arg_list(arguments) + + inf = _get_reduction_source( ctx, out_type, out_type_size, - neutral, reduce_expr, map_expr, arguments, + neutral, reduce_expr, map_expr, parsed_args, name, preamble, device, max_group_size) inf.program = cl.Program(ctx, inf.source) inf.program.build(options) inf.kernel = getattr(inf.program, name) - from pyopencl.tools import parse_c_arg, ScalarArg - - inf.arg_types = [parse_c_arg(arg) for arg in arguments.split(",")] - scalar_arg_dtypes = [None] - for arg_type in inf.arg_types: - if isinstance(arg_type, ScalarArg): - scalar_arg_dtypes.append(arg_type.dtype) - else: - scalar_arg_dtypes.append(None) - scalar_arg_dtypes.extend([np.uint32]*2) + inf.arg_types = parsed_args - inf.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes) + inf.kernel.set_scalar_arg_dtypes( + [None] + + get_arg_list_scalar_arg_dtypes(inf.arg_types) + + [np.uint32]*2) return inf @@ -390,7 +387,7 @@ def get_sum_kernel(ctx, dtype_out, dtype_in): dtype_out = dtype_in return ReductionKernel(ctx, dtype_out, "0", "a+b", - arguments="__global const %(tp)s *in" + arguments="const %(tp)s *in" % {"tp": dtype_to_ctype(dtype_in)}) @@ -450,8 +447,8 @@ def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None): return ReductionKernel(ctx, dtype_out, neutral="0", reduce_expr="a+b", map_expr=map_expr, arguments= - "__global const %(tp_a)s *a, " - "__global const %(tp_b)s *b" % { + "const %(tp_a)s *a, " + "const %(tp_b)s *b" % { "tp_a": dtype_to_ctype(dtype_a), "tp_b": dtype_to_ctype(dtype_b), }) @@ -477,9 +474,9 @@ def get_subset_dot_kernel(ctx, dtype_out, dtype_subset, dtype_a=None, dtype_b=No return ReductionKernel(ctx, dtype_out, neutral="0", reduce_expr="a+b", map_expr="a[lookup_tbl[i]]*b[lookup_tbl[i]]", arguments= - "__global const %(tp_lut)s *lookup_tbl, " - "__global const %(tp_a)s *a, " - "__global const %(tp_b)s *b" % { + "const %(tp_lut)s *lookup_tbl, " + "const %(tp_a)s *a, " + "const %(tp_b)s *b" % { "tp_lut": dtype_to_ctype(dtype_subset), "tp_a": dtype_to_ctype(dtype_a), "tp_b": dtype_to_ctype(dtype_b), @@ -520,7 +517,7 @@ def get_minmax_kernel(ctx, what, dtype): return ReductionKernel(ctx, dtype, neutral=get_minmax_neutral(what, dtype), reduce_expr="%(reduce_expr)s" % {"reduce_expr": reduce_expr}, - arguments="__global const %(tp)s *in" % { + arguments="const %(tp)s *in" % { "tp": dtype_to_ctype(dtype), }, preamble="#define MY_INFINITY (1./0)") @@ -541,8 +538,8 @@ def get_subset_minmax_kernel(ctx, what, dtype, dtype_subset): reduce_expr="%(reduce_expr)s" % {"reduce_expr": reduce_expr}, map_expr="in[lookup_tbl[i]]", arguments= - "__global const %(tp_lut)s *lookup_tbl, " - "__global const %(tp)s *in" % { + "const %(tp_lut)s *lookup_tbl, " + "const %(tp)s *in" % { "tp": dtype_to_ctype(dtype), "tp_lut": dtype_to_ctype(dtype_subset), }, preamble="#define MY_INFINITY (1./0)") diff --git a/pyopencl/scan.py b/pyopencl/scan.py index c7c39de2..3cf67cb8 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -34,7 +34,8 @@ import numpy as np import pyopencl as cl import pyopencl.array from pyopencl.tools import (dtype_to_ctype, bitlog2, - KernelTemplateBase, _process_code_for_macro) + KernelTemplateBase, _process_code_for_macro, + get_arg_list_scalar_arg_dtypes) import pyopencl._mymako as mako from pyopencl._cluda import CLUDA_PREAMBLE @@ -730,31 +731,6 @@ def _round_down_to_power_of_2(val): assert result <= val return result -def _parse_args(arguments): - if isinstance(arguments, str): - arguments = arguments.split(",") - - def parse_single_arg(obj): - if isinstance(obj, str): - from pyopencl.tools import parse_c_arg - return parse_c_arg(obj) - else: - return obj - - return [parse_single_arg(arg) for arg in arguments] - -def _get_scalar_arg_dtypes(arg_types): - result = [] - - from pyopencl.tools import ScalarArg - for arg_type in arg_types: - if isinstance(arg_type, ScalarArg): - result.append(arg_type.dtype) - else: - result.append(None) - - return result - _PREFIX_WORDS = set(""" ldata partial_scan_buffer global scan_offset segment_start_in_k_group carry @@ -982,7 +958,8 @@ class _GenericScanKernelBase(object): self.devices = devices self.options = options - self.parsed_args = _parse_args(arguments) + from pyopencl.tools import parse_arg_list + self.parsed_args = parse_arg_list(arguments) from pyopencl.tools import VectorArg self.first_array_idx = [ i for i, arg in enumerate(self.parsed_args) @@ -1183,7 +1160,7 @@ class GenericScanKernel(_GenericScanKernelBase): final_update_prg, self.name_prefix+"_final_update") update_scalar_arg_dtypes = ( - _get_scalar_arg_dtypes(self.parsed_args) + get_arg_list_scalar_arg_dtypes(self.parsed_args) + [self.index_dtype, self.index_dtype, None, None]) if self.is_segmented: update_scalar_arg_dtypes.append(None) # g_first_segment_start_in_interval @@ -1223,7 +1200,7 @@ class GenericScanKernel(_GenericScanKernelBase): def build_scan_kernel(self, max_wg_size, arguments, input_expr, is_segment_start_expr, input_fetch_exprs, is_first_level, store_segment_start_flags, k_group_size): - scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments) + scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments) # Empirically found on Nv hardware: no need to be bigger than this size wg_size = _round_down_to_power_of_2( @@ -1437,7 +1414,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase): self.kernel = getattr( scan_prg, self.name_prefix+"_debug_scan") scalar_arg_dtypes = ( - _get_scalar_arg_dtypes(self.parsed_args) + get_arg_list_scalar_arg_dtypes(self.parsed_args) + [self.index_dtype]) self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes) diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 22b782e6..ad26be72 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -236,7 +236,10 @@ def pytest_generate_tests_for_pyopencl(metafunc): # {{{ C argument lists -class Argument: +class Argument(object): + pass + +class DtypedArgument(Argument): def __init__(self, dtype, name): self.dtype = np.dtype(dtype) self.name = name @@ -247,27 +250,66 @@ class Argument: self.name, self.dtype) -class VectorArg(Argument): +class VectorArg(DtypedArgument): def declarator(self): return "__global %s *%s" % (dtype_to_ctype(self.dtype), self.name) -class ScalarArg(Argument): +class ScalarArg(DtypedArgument): def declarator(self): return "%s %s" % (dtype_to_ctype(self.dtype), self.name) +class OtherArgument(Argument): + def __init__(self, declarator): + self.declarator = declarator + + def declarator(self): + return self.declarator + def parse_c_arg(c_arg): - c_arg = (c_arg - .replace("__global", "") - .replace("__local", "") - .replace("__constant", "")) + for aspace in ["__local", "__constant"]: + if aspace in c_arg: + raise RuntimeError("cannot deal with local or constant " + "OpenCL address spaces in C argument lists ") + + c_arg = c_arg.replace("__global", "") from pyopencl.compyte.dtypes import parse_c_arg_backend return parse_c_arg_backend(c_arg, ScalarArg, VectorArg) +def parse_arg_list(arguments): + """Parse a list of kernel arguments. *arguments* may be a comma-separate list + of C declarators in a string, a list of strings representing C declarators, + or :class:`Argument` objects. + """ + + if isinstance(arguments, str): + arguments = arguments.split(",") + + def parse_single_arg(obj): + if isinstance(obj, str): + from pyopencl.tools import parse_c_arg + return parse_c_arg(obj) + else: + return obj + + return [parse_single_arg(arg) for arg in arguments] + +def get_arg_list_scalar_arg_dtypes(arg_types): + result = [] + + from pyopencl.tools import ScalarArg + for arg_type in arg_types: + if isinstance(arg_type, ScalarArg): + result.append(arg_type.dtype) + else: + result.append(None) + + return result + # }}} -- GitLab