From 2c560eea2a50df924d8bc3678482e929a4f91864 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 11 Dec 2012 23:40:55 -0500
Subject: [PATCH] Centralize arg list parsing.

---
 pyopencl/algorithm.py   |  4 +--
 pyopencl/elementwise.py | 17 ++++---------
 pyopencl/reduction.py   | 49 +++++++++++++++++-------------------
 pyopencl/scan.py        | 37 ++++++---------------------
 pyopencl/tools.py       | 56 +++++++++++++++++++++++++++++++++++------
 5 files changed, 86 insertions(+), 77 deletions(-)

diff --git a/pyopencl/algorithm.py b/pyopencl/algorithm.py
index 4e72a6a5..8aef52ec 100644
--- a/pyopencl/algorithm.py
+++ b/pyopencl/algorithm.py
@@ -356,8 +356,8 @@ class RadixSort(object):
 
         # {{{ arg processing
 
-        from pyopencl.scan import _parse_args
-        self.arguments = _parse_args(arguments)
+        from pyopencl.tools import parse_arg_list
+        self.arguments = parse_arg_list(arguments)
         del arguments
 
         self.sort_arg_names = sort_arg_names
diff --git a/pyopencl/elementwise.py b/pyopencl/elementwise.py
index 75baa22e..159a594a 100644
--- a/pyopencl/elementwise.py
+++ b/pyopencl/elementwise.py
@@ -100,11 +100,9 @@ def get_elwise_program(context, arguments, operation,
 def get_elwise_kernel_and_types(context, arguments, operation,
         name="elwise_kernel", options=[], preamble="", use_range=False,
         **kwargs):
-    if isinstance(arguments, str):
-        from pyopencl.tools import parse_c_arg
-        parsed_args = [parse_c_arg(arg) for arg in arguments.split(",")]
-    else:
-        parsed_args = arguments
+
+    from pyopencl.tools import parse_arg_list
+    parsed_args = parse_arg_list(arguments)
 
     auto_preamble = kwargs.pop("auto_preamble", True)
 
@@ -143,15 +141,10 @@ def get_elwise_kernel_and_types(context, arguments, operation,
         name=name, options=options, preamble=preamble,
         use_range=use_range, **kwargs)
 
-    scalar_arg_dtypes = []
-    for arg in parsed_args:
-        if isinstance(arg, ScalarArg):
-            scalar_arg_dtypes.append(arg.dtype)
-        else:
-            scalar_arg_dtypes.append(None)
+    from pyopencl.tools import get_arg_list_scalar_arg_dtypes
 
     kernel = getattr(prg, name)
-    kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
+    kernel.set_scalar_arg_dtypes(get_arg_list_scalar_arg_dtypes(parsed_args))
 
     return kernel, parsed_args
 
diff --git a/pyopencl/reduction.py b/pyopencl/reduction.py
index 68562b11..c0f6b6d3 100644
--- a/pyopencl/reduction.py
+++ b/pyopencl/reduction.py
@@ -140,9 +140,9 @@ KERNEL = """//CL//
 
 
 
-def  get_reduction_source(
+def  _get_reduction_source(
          ctx, out_type, out_type_size,
-         neutral, reduce_expr, map_expr, arguments,
+         neutral, reduce_expr, map_expr, parsed_args,
          name="reduce_kernel", preamble="",
          device=None, max_group_size=None):
 
@@ -198,7 +198,7 @@ def  get_reduction_source(
     from pyopencl.characterize import has_double_support, has_amd_double_support
     src = str(Template(KERNEL).render(
         out_type=out_type,
-        arguments=arguments,
+        arguments=", ".join(arg.declarator() for arg in parsed_args),
         group_size=group_size,
         no_sync_size=no_sync_size,
         neutral=neutral,
@@ -236,33 +236,30 @@ def get_reduction_kernel(stage,
             map_expr = "in[i]"
 
     if stage == 2:
-        in_arg = "__global const %s *pyopencl_reduction_inp" % out_type
+        in_arg = "const %s *pyopencl_reduction_inp" % out_type
         if arguments:
             arguments = in_arg + ", " + arguments
         else:
             arguments = in_arg
 
-    inf = get_reduction_source(
+    from pyopencl.tools import parse_arg_list, get_arg_list_scalar_arg_dtypes
+    parsed_args = parse_arg_list(arguments)
+
+    inf = _get_reduction_source(
             ctx, out_type, out_type_size,
-            neutral, reduce_expr, map_expr, arguments,
+            neutral, reduce_expr, map_expr, parsed_args,
             name, preamble, device, max_group_size)
 
     inf.program = cl.Program(ctx, inf.source)
     inf.program.build(options)
     inf.kernel = getattr(inf.program, name)
 
-    from pyopencl.tools import parse_c_arg, ScalarArg
-
-    inf.arg_types = [parse_c_arg(arg) for arg in arguments.split(",")]
-    scalar_arg_dtypes = [None]
-    for arg_type in inf.arg_types:
-        if isinstance(arg_type, ScalarArg):
-            scalar_arg_dtypes.append(arg_type.dtype)
-        else:
-            scalar_arg_dtypes.append(None)
-    scalar_arg_dtypes.extend([np.uint32]*2)
+    inf.arg_types = parsed_args
 
-    inf.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
+    inf.kernel.set_scalar_arg_dtypes(
+            [None]
+            + get_arg_list_scalar_arg_dtypes(inf.arg_types)
+            + [np.uint32]*2)
 
     return inf
 
@@ -390,7 +387,7 @@ def get_sum_kernel(ctx, dtype_out, dtype_in):
         dtype_out = dtype_in
 
     return ReductionKernel(ctx, dtype_out, "0", "a+b",
-            arguments="__global const %(tp)s *in"
+            arguments="const %(tp)s *in"
             % {"tp": dtype_to_ctype(dtype_in)})
 
 
@@ -450,8 +447,8 @@ def get_dot_kernel(ctx, dtype_out, dtype_a=None, dtype_b=None):
     return ReductionKernel(ctx, dtype_out, neutral="0",
             reduce_expr="a+b", map_expr=map_expr,
             arguments=
-            "__global const %(tp_a)s *a, "
-            "__global const %(tp_b)s *b" % {
+            "const %(tp_a)s *a, "
+            "const %(tp_b)s *b" % {
                 "tp_a": dtype_to_ctype(dtype_a),
                 "tp_b": dtype_to_ctype(dtype_b),
                 })
@@ -477,9 +474,9 @@ def get_subset_dot_kernel(ctx, dtype_out, dtype_subset, dtype_a=None, dtype_b=No
     return ReductionKernel(ctx, dtype_out, neutral="0",
             reduce_expr="a+b", map_expr="a[lookup_tbl[i]]*b[lookup_tbl[i]]",
             arguments=
-            "__global const %(tp_lut)s *lookup_tbl, "
-            "__global const %(tp_a)s *a, "
-            "__global const %(tp_b)s *b" % {
+            "const %(tp_lut)s *lookup_tbl, "
+            "const %(tp_a)s *a, "
+            "const %(tp_b)s *b" % {
             "tp_lut": dtype_to_ctype(dtype_subset),
             "tp_a": dtype_to_ctype(dtype_a),
             "tp_b": dtype_to_ctype(dtype_b),
@@ -520,7 +517,7 @@ def get_minmax_kernel(ctx, what, dtype):
     return ReductionKernel(ctx, dtype,
             neutral=get_minmax_neutral(what, dtype),
             reduce_expr="%(reduce_expr)s" % {"reduce_expr": reduce_expr},
-            arguments="__global const %(tp)s *in" % {
+            arguments="const %(tp)s *in" % {
                 "tp": dtype_to_ctype(dtype),
                 }, preamble="#define MY_INFINITY (1./0)")
 
@@ -541,8 +538,8 @@ def get_subset_minmax_kernel(ctx, what, dtype, dtype_subset):
             reduce_expr="%(reduce_expr)s" % {"reduce_expr": reduce_expr},
             map_expr="in[lookup_tbl[i]]",
             arguments=
-            "__global const %(tp_lut)s *lookup_tbl, "
-            "__global const %(tp)s *in"  % {
+            "const %(tp_lut)s *lookup_tbl, "
+            "const %(tp)s *in"  % {
             "tp": dtype_to_ctype(dtype),
             "tp_lut": dtype_to_ctype(dtype_subset),
             }, preamble="#define MY_INFINITY (1./0)")
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index c7c39de2..3cf67cb8 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -34,7 +34,8 @@ import numpy as np
 import pyopencl as cl
 import pyopencl.array
 from pyopencl.tools import (dtype_to_ctype, bitlog2,
-        KernelTemplateBase, _process_code_for_macro)
+        KernelTemplateBase, _process_code_for_macro,
+        get_arg_list_scalar_arg_dtypes)
 import pyopencl._mymako as mako
 from pyopencl._cluda import CLUDA_PREAMBLE
 
@@ -730,31 +731,6 @@ def _round_down_to_power_of_2(val):
     assert result <= val
     return result
 
-def _parse_args(arguments):
-    if isinstance(arguments, str):
-        arguments = arguments.split(",")
-
-    def parse_single_arg(obj):
-        if isinstance(obj, str):
-            from pyopencl.tools import parse_c_arg
-            return parse_c_arg(obj)
-        else:
-            return obj
-
-    return [parse_single_arg(arg) for arg in arguments]
-
-def _get_scalar_arg_dtypes(arg_types):
-    result = []
-
-    from pyopencl.tools import ScalarArg
-    for arg_type in arg_types:
-        if isinstance(arg_type, ScalarArg):
-            result.append(arg_type.dtype)
-        else:
-            result.append(None)
-
-    return result
-
 _PREFIX_WORDS = set("""
         ldata partial_scan_buffer global scan_offset
         segment_start_in_k_group carry
@@ -982,7 +958,8 @@ class _GenericScanKernelBase(object):
         self.devices = devices
         self.options = options
 
-        self.parsed_args = _parse_args(arguments)
+        from pyopencl.tools import parse_arg_list
+        self.parsed_args = parse_arg_list(arguments)
         from pyopencl.tools import VectorArg
         self.first_array_idx = [
                 i for i, arg in enumerate(self.parsed_args)
@@ -1183,7 +1160,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                 final_update_prg,
                 self.name_prefix+"_final_update")
         update_scalar_arg_dtypes = (
-                _get_scalar_arg_dtypes(self.parsed_args)
+                get_arg_list_scalar_arg_dtypes(self.parsed_args)
                 + [self.index_dtype, self.index_dtype, None, None])
         if self.is_segmented:
             update_scalar_arg_dtypes.append(None) # g_first_segment_start_in_interval
@@ -1223,7 +1200,7 @@ class GenericScanKernel(_GenericScanKernelBase):
     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
             is_segment_start_expr, input_fetch_exprs, is_first_level,
             store_segment_start_flags, k_group_size):
-        scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments)
+        scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
 
         # Empirically found on Nv hardware: no need to be bigger than this size
         wg_size = _round_down_to_power_of_2(
@@ -1437,7 +1414,7 @@ class GenericDebugScanKernel(_GenericScanKernelBase):
         self.kernel = getattr(
                 scan_prg, self.name_prefix+"_debug_scan")
         scalar_arg_dtypes = (
-                _get_scalar_arg_dtypes(self.parsed_args)
+                get_arg_list_scalar_arg_dtypes(self.parsed_args)
                 + [self.index_dtype])
         self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
 
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index 22b782e6..ad26be72 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -236,7 +236,10 @@ def pytest_generate_tests_for_pyopencl(metafunc):
 
 # {{{ C argument lists
 
-class Argument:
+class Argument(object):
+    pass
+
+class DtypedArgument(Argument):
     def __init__(self, dtype, name):
         self.dtype = np.dtype(dtype)
         self.name = name
@@ -247,27 +250,66 @@ class Argument:
                 self.name,
                 self.dtype)
 
-class VectorArg(Argument):
+class VectorArg(DtypedArgument):
     def declarator(self):
         return "__global %s *%s" % (dtype_to_ctype(self.dtype), self.name)
 
-class ScalarArg(Argument):
+class ScalarArg(DtypedArgument):
     def declarator(self):
         return "%s %s" % (dtype_to_ctype(self.dtype), self.name)
 
+class OtherArgument(Argument):
+    def __init__(self, declarator):
+        self._declarator = declarator  # must not shadow declarator()
+
+    def declarator(self):
+        return self._declarator
+
 
 
 
 
 def parse_c_arg(c_arg):
-    c_arg = (c_arg
-            .replace("__global", "")
-            .replace("__local", "")
-            .replace("__constant", ""))
+    for aspace in ["__local", "__constant"]:
+        if aspace in c_arg:
+            raise RuntimeError("cannot deal with local or constant "
+                    "OpenCL address spaces in C argument lists ")
+
+    c_arg = c_arg.replace("__global", "")  # VectorArg.declarator() adds __global itself
 
     from pyopencl.compyte.dtypes import parse_c_arg_backend
     return parse_c_arg_backend(c_arg, ScalarArg, VectorArg)
 
+def parse_arg_list(arguments):
+    """Parse a list of kernel arguments.
+
+    *arguments* may be a string holding a comma-separated list of C
+    declarators, or a sequence whose entries are C declarator strings
+    or already-constructed :class:`Argument` objects.
+
+    Returns a list with one entry per argument: strings are converted by
+    :func:`parse_c_arg`, all other objects are passed through unchanged.
+    """
+
+    if isinstance(arguments, str):
+        arguments = arguments.split(",")
+
+    # parse_c_arg is defined in this module, so no import is needed here.
+    return [parse_c_arg(arg) if isinstance(arg, str) else arg
+            for arg in arguments]
+
+def get_arg_list_scalar_arg_dtypes(arg_types):
+    """Return the per-argument dtype list for *arg_types*.
+
+    The result has one entry per argument: the argument's ``dtype`` for a
+    :class:`ScalarArg`, and *None* for every other kind of argument. The
+    callers in this patch hand the list to ``set_scalar_arg_dtypes``.
+    """
+
+    # ScalarArg is defined in this module; no import is required.
+    return [arg.dtype if isinstance(arg, ScalarArg) else None
+            for arg in arg_types]
+
 # }}}
 
 
-- 
GitLab