From ae9253c6106088d881cca602ed727671ab20aa7f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sat, 28 Jul 2012 16:13:50 -0400 Subject: [PATCH] Let scan accept list of VectorArg/ScalarArg objects. --- pyopencl/scan.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 5ac759f5..d4744e5c 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -637,8 +637,7 @@ void ${name_prefix}_final_update( # {{{ helpers def _process_code_for_macro(code): - if code.startswith("//CL//"): - code = code[6:] + code = code.replace("//CL//", "\n") if "//" in code: raise RuntimeError("end-of-line comments ('//') may not be used in " @@ -655,8 +654,17 @@ def _round_down_to_power_of_2(val): return result def _parse_args(arguments): - from pyopencl.tools import parse_c_arg - return [parse_c_arg(arg) for arg in arguments.split(",")] + if isinstance(arguments, str): + arguments = arguments.split(",") + + def parse_single_arg(obj): + if isinstance(obj, str): + from pyopencl.tools import parse_c_arg + return parse_c_arg(obj) + else: + return obj + + return [parse_single_arg(arg) for arg in arguments] def _get_scalar_arg_dtypes(arg_types): result = [] @@ -863,8 +871,7 @@ class _GenericScanKernelBase(object): self.devices = devices self.options = options - self.arguments = arguments - self.parsed_args = _parse_args(self.arguments) + self.parsed_args = _parse_args(arguments) from pyopencl.tools import VectorArg self.first_array_idx = [ i for i, arg in enumerate(self.parsed_args) @@ -933,7 +940,7 @@ class GenericScanKernel(_GenericScanKernelBase): while True: candidate_scan_info = self.build_scan_kernel( - max_scan_wg_size, self.arguments, self.input_expr, + max_scan_wg_size, self.parsed_args, self.input_expr, self.is_segment_start_expr, input_fetch_exprs=self.input_fetch_exprs, is_first_level=True, @@ -964,13 +971,15 @@ class GenericScanKernel(_GenericScanKernelBase): # {{{ build second-level scan - second_level_arguments = self.arguments.split(",") + [ - "__global %s *interval_sums" % dtype_to_ctype(self.dtype)] + from pyopencl.tools import VectorArg + second_level_arguments = self.parsed_args + [ + VectorArg(self.dtype, "interval_sums")] + second_level_build_kwargs = {} if self.is_segmented: second_level_arguments.append( - "__global %s *g_first_segment_start_in_interval_input" - % dtype_to_ctype(self.index_dtype)) + VectorArg(self.index_dtype, + "g_first_segment_start_in_interval_input")) # is_segment_start_expr answers the question "should previous sums # spill over into this item". And since g_first_segment_start_in_interval_input @@ -983,7 +992,7 @@ class GenericScanKernel(_GenericScanKernelBase): self.second_level_scan_info = self.build_scan_kernel( max_scan_wg_size, - arguments=", ".join(second_level_arguments), + arguments=second_level_arguments, input_expr="interval_sums[i]", input_fetch_exprs=[], is_first_level=False, @@ -1005,8 +1014,8 @@ class GenericScanKernel(_GenericScanKernelBase): final_update_tpl = _make_template(UPDATE_SOURCE) final_update_src = str(final_update_tpl.render( wg_size=self.update_wg_size, - output_statement=_process_code_for_macro(self.output_statement), - argument_signature=_process_code_for_macro(self.arguments), + output_statement=self.output_statement, + argument_signature=", ".join(arg.declarator() for arg in self.parsed_args), is_segment_start_expr=self.is_segment_start_expr, input_expr=_process_code_for_macro(self.input_expr), use_lookbehind_update=use_lookbehind_update, @@ -1032,7 +1041,7 @@ class GenericScanKernel(_GenericScanKernelBase): def build_scan_kernel(self, max_wg_size, arguments, input_expr, is_segment_start_expr, input_fetch_exprs, is_first_level, store_segment_start_flags): - scalar_arg_dtypes = _get_scalar_arg_dtypes(_parse_args(arguments)) + scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments) # Thrust says that 128 is big enough for GT200 wg_size = _round_down_to_power_of_2( @@ -1052,7 +1061,7 @@ class GenericScanKernel(_GenericScanKernelBase): wg_size=wg_size, input_expr=input_expr, k_group_size=k_group_size, - argument_signature=_process_code_for_macro(arguments), + argument_signature=", ".join(arg.declarator() for arg in arguments), is_segment_start_expr=is_segment_start_expr, input_fetch_exprs=input_fetch_exprs, is_first_level=is_first_level, -- GitLab