From ae9253c6106088d881cca602ed727671ab20aa7f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 28 Jul 2012 16:13:50 -0400
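Subject: [PATCH] Let scan accept list of VectorArg/ScalarArg objects.

_parse_args() now accepts either a comma-separated string of C
argument declarations or a sequence whose items are declaration
strings or already-constructed argument objects. Strings are still
run through parse_c_arg(); any other item is passed through unchanged.

Because the arguments are no longer guaranteed to be a string,
self.arguments is dropped in favor of self.parsed_args, kernel
argument signatures are generated by joining arg.declarator() over
the parsed arguments, and the extra second-level scan arguments
(interval_sums, g_first_segment_start_in_interval_input) are built
as VectorArg objects instead of formatted C strings.

Also included: _process_code_for_macro() now replaces every "//CL//"
marker with a newline rather than stripping only a leading one, and
output_statement is handed to the final-update template without
macro processing.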
---
 pyopencl/scan.py | 41 +++++++++++++++++++++++++----------------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 5ac759f5..d4744e5c 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -637,8 +637,7 @@ void ${name_prefix}_final_update(
 # {{{ helpers
 
 def _process_code_for_macro(code):
-    if code.startswith("//CL//"):
-        code = code[6:]
+    code = code.replace("//CL//", "\n")
 
     if "//" in code:
         raise RuntimeError("end-of-line comments ('//') may not be used in "
@@ -655,8 +654,17 @@ def _round_down_to_power_of_2(val):
     return result
 
 def _parse_args(arguments):
-    from pyopencl.tools import parse_c_arg
-    return [parse_c_arg(arg) for arg in arguments.split(",")]
+    if isinstance(arguments, str):
+        arguments = arguments.split(",")
+
+    def parse_single_arg(obj):
+        if isinstance(obj, str):
+            from pyopencl.tools import parse_c_arg
+            return parse_c_arg(obj)
+        else:
+            return obj
+
+    return [parse_single_arg(arg) for arg in arguments]
 
 def _get_scalar_arg_dtypes(arg_types):
     result = []
@@ -863,8 +871,7 @@ class _GenericScanKernelBase(object):
         self.devices = devices
         self.options = options
 
-        self.arguments = arguments
-        self.parsed_args = _parse_args(self.arguments)
+        self.parsed_args = _parse_args(arguments)
         from pyopencl.tools import VectorArg
         self.first_array_idx = [
                 i for i, arg in enumerate(self.parsed_args)
@@ -933,7 +940,7 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         while True:
             candidate_scan_info = self.build_scan_kernel(
-                    max_scan_wg_size, self.arguments, self.input_expr,
+                    max_scan_wg_size, self.parsed_args, self.input_expr,
                     self.is_segment_start_expr,
                     input_fetch_exprs=self.input_fetch_exprs,
                     is_first_level=True,
@@ -964,13 +971,15 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         # {{{ build second-level scan
 
-        second_level_arguments = self.arguments.split(",") + [
-                "__global %s *interval_sums" % dtype_to_ctype(self.dtype)]
+        from pyopencl.tools import VectorArg
+        second_level_arguments = self.parsed_args + [
+                VectorArg(self.dtype, "interval_sums")]
+
         second_level_build_kwargs = {}
         if self.is_segmented:
             second_level_arguments.append(
-                    "__global %s *g_first_segment_start_in_interval_input"
-                    % dtype_to_ctype(self.index_dtype))
+                    VectorArg(self.index_dtype,
+                        "g_first_segment_start_in_interval_input"))
 
             # is_segment_start_expr answers the question "should previous sums
             # spill over into this item". And since g_first_segment_start_in_interval_input
@@ -983,7 +992,7 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         self.second_level_scan_info = self.build_scan_kernel(
                 max_scan_wg_size,
-                arguments=", ".join(second_level_arguments),
+                arguments=second_level_arguments,
                 input_expr="interval_sums[i]",
                 input_fetch_exprs=[],
                 is_first_level=False,
@@ -1005,8 +1014,8 @@ class GenericScanKernel(_GenericScanKernelBase):
         final_update_tpl = _make_template(UPDATE_SOURCE)
         final_update_src = str(final_update_tpl.render(
             wg_size=self.update_wg_size,
-            output_statement=_process_code_for_macro(self.output_statement),
-            argument_signature=_process_code_for_macro(self.arguments),
+            output_statement=self.output_statement,
+            argument_signature=", ".join(arg.declarator() for arg in self.parsed_args),
             is_segment_start_expr=self.is_segment_start_expr,
             input_expr=_process_code_for_macro(self.input_expr),
             use_lookbehind_update=use_lookbehind_update,
@@ -1032,7 +1041,7 @@ class GenericScanKernel(_GenericScanKernelBase):
     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
             is_segment_start_expr, input_fetch_exprs, is_first_level,
             store_segment_start_flags):
-        scalar_arg_dtypes = _get_scalar_arg_dtypes(_parse_args(arguments))
+        scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments)
 
         # Thrust says that 128 is big enough for GT200
         wg_size = _round_down_to_power_of_2(
@@ -1052,7 +1061,7 @@ class GenericScanKernel(_GenericScanKernelBase):
             wg_size=wg_size,
             input_expr=input_expr,
             k_group_size=k_group_size,
-            argument_signature=_process_code_for_macro(arguments),
+            argument_signature=", ".join(arg.declarator() for arg in arguments),
             is_segment_start_expr=is_segment_start_expr,
             input_fetch_exprs=input_fetch_exprs,
             is_first_level=is_first_level,
-- 
GitLab