From 2f50d72ef1abb6c53e6c7250ac10d962a64c4bcc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Fri, 27 Jul 2012 01:49:35 -0400
Subject: [PATCH] Introduce scan debug kernel.

---
 doc/source/array.rst |  10 +++
 pyopencl/scan.py     | 164 +++++++++++++++++++++++++++++++++++--------
 test/test_array.py   |  36 ++++++----
 3 files changed, 165 insertions(+), 45 deletions(-)

diff --git a/doc/source/array.rst b/doc/source/array.rst
index 8546747e..560c68ae 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -765,6 +765,16 @@ Making Custom Scan Kernels
         *queue* and *allocator* default to the ones provided on the first
         :class:`pyopencl.array.Array` in *args*.
 
+Debugging aids
+~~~~~~~~~~~~~~
+
+.. class:: GenericDebugScanKernel
+
+    Performs the same function and has the same interface as
+    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan.  Works
+    best on CPU platforms, and helps isolate bugs in scans by removing the
+    potential for issues originating in parallel execution.
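+
+    For example, the following sketch (which assumes an existing context
+    *ctx*, an ``int32`` :class:`pyopencl.array.Array` *ary*, ``numpy``
+    imported as ``np``, and ``pyopencl.array`` imported as ``cl_array``)
+    runs a cumulative sum through the debug kernel::
+
+        knl = GenericDebugScanKernel(
+                ctx, np.int32,
+                arguments="__global int *ary, __global int *out",
+                input_expr="ary[i]",
+                scan_expr="a+b", neutral="0",
+                output_statement="out[i] = item;")
+
+        out = cl_array.empty_like(ary)
+        knl(ary, out)
+
+    Because the scan is carried out by a single work item, it will be slow
+    for large inputs; it is a debugging aid, not a production scan.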
+
 .. _predefined-scans:
 
 Pre-defined higher-level operations
diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index efd2a1ae..cf70cee6 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -773,12 +773,8 @@ class _ScanKernelInfo(Record):
 
 # }}}
 
-class GenericScanKernel(object):
-    """Generates and executes code that performs prefix sums ("scans") on
-    arbitrary types, with many possible tweaks.
-    """
-
-    # {{{ constructor
+class _GenericScanKernelBase(object):
+    # {{{ constructor, argument processing
 
     def __init__(self, ctx, dtype,
             arguments, input_expr, scan_expr, neutral, output_statement,
@@ -848,9 +844,6 @@ class GenericScanKernel(object):
         being processed in the scan.
         """
 
-        if isinstance(self, ExclusiveScanKernel) and neutral is None:
-            raise ValueError("neutral element is required for exclusive scan")
-
         self.context = ctx
         dtype = self.dtype = np.dtype(dtype)
 
@@ -877,21 +870,27 @@ class GenericScanKernel(object):
                 i for i, arg in enumerate(self.parsed_args)
                 if isinstance(arg, VectorArg)][0]
 
+        self.input_expr = input_expr
+
+        self.is_segment_start_expr = is_segment_start_expr
         self.is_segmented = is_segment_start_expr is not None
         if self.is_segmented:
             is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)
 
-        use_lookbehind_update = "prev_item" in output_statement
-        self.store_segment_start_flags = self.is_segmented and use_lookbehind_update
+        self.output_statement = output_statement
 
         for name, arg_name, ife_offset in input_fetch_exprs:
             if ife_offset not in [0, -1]:
                 raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
+        self.input_fetch_exprs = input_fetch_exprs
 
         arg_ctypes = {}
         for arg in self.parsed_args:
             arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
 
+        self.options = options
+        self.name_prefix = name_prefix
+
         # {{{ set up shared code dict
 
         from pytools import all
@@ -913,6 +912,19 @@ class GenericScanKernel(object):
 
         # }}}
 
+        self.finish_setup()
+
+    # }}}
+
+class GenericScanKernel(_GenericScanKernelBase):
+    """Generates and executes code that performs prefix sums ("scans") on
+    arbitrary types, with many possible tweaks.
+    """
+
+    def finish_setup(self):
+        use_lookbehind_update = "prev_item" in self.output_statement
+        self.store_segment_start_flags = self.is_segmented and use_lookbehind_update
+
         # {{{ loop to find usable workgroup size, build first-level scan
 
         trip_count = 0
@@ -921,9 +933,9 @@ class GenericScanKernel(object):
 
         while True:
             candidate_scan_info = self.build_scan_kernel(
-                    max_scan_wg_size, arguments, input_expr,
-                    is_segment_start_expr,
-                    input_fetch_exprs=input_fetch_exprs,
+                    max_scan_wg_size, self.arguments, self.input_expr,
+                    self.is_segment_start_expr,
+                    input_fetch_exprs=self.input_fetch_exprs,
                     is_first_level=True,
                     store_segment_start_flags=self.store_segment_start_flags)
 
@@ -953,7 +965,7 @@ class GenericScanKernel(object):
         # {{{ build second-level scan
 
         second_level_arguments = self.arguments.split(",") + [
-                "__global %s *interval_sums" % dtype_to_ctype(dtype)]
+                "__global %s *interval_sums" % dtype_to_ctype(self.dtype)]
         second_level_build_kwargs = {}
         if self.is_segmented:
             second_level_arguments.append(
@@ -993,19 +1005,17 @@ class GenericScanKernel(object):
         final_update_tpl = _make_template(UPDATE_SOURCE)
         final_update_src = str(final_update_tpl.render(
             wg_size=self.update_wg_size,
-            output_statement=_process_code_for_macro(output_statement),
-            argument_signature=_process_code_for_macro(arguments),
-            is_segment_start_expr=is_segment_start_expr,
-            input_expr=_process_code_for_macro(input_expr),
+            output_statement=_process_code_for_macro(self.output_statement),
+            argument_signature=_process_code_for_macro(self.arguments),
+            is_segment_start_expr=self.is_segment_start_expr,
+            input_expr=_process_code_for_macro(self.input_expr),
             use_lookbehind_update=use_lookbehind_update,
             **self.code_variables))
 
-        #with open("update.cl", "wt") as f: f.write(final_update_src)
-
-        final_update_prg = cl.Program(self.context, final_update_src).build(options)
+        final_update_prg = cl.Program(self.context, final_update_src).build(self.options)
         self.final_update_knl = getattr(
                 final_update_prg,
-                name_prefix+"_final_update")
+                self.name_prefix+"_final_update")
         update_scalar_arg_dtypes = (
                 _get_scalar_arg_dtypes(self.parsed_args)
                 + [self.index_dtype, self.index_dtype, None, None])
@@ -1017,8 +1027,6 @@ class GenericScanKernel(object):
 
         # }}}
 
-    # }}}
-
     # {{{ scan kernel build
 
     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
@@ -1051,8 +1059,6 @@ class GenericScanKernel(object):
             store_segment_start_flags=store_segment_start_flags,
             **self.code_variables))
 
-        #with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src)
-
         prg = cl.Program(self.context, scan_src).build(self.options)
 
         knl = getattr(
@@ -1186,9 +1192,107 @@ class GenericScanKernel(object):
 
 # }}}
 
+# {{{ debug kernel
+
+DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + """//CL//
+
+#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; }
+
+KERNEL
+REQD_WG_SIZE(1, 1, 1)
+void ${name_prefix}_debug_scan(
+    ${argument_signature},
+    const index_type N)
+{
+    scan_type item = ${neutral};
+    scan_type prev_item;
+
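+    // Process the entire array sequentially with a single work item.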
+    for (index_type i = 0; i < N; ++i)
+    {
+        %for name, arg_name, ife_offset in input_fetch_exprs:
+            ${arg_ctypes[arg_name]} ${name};
+            %if ife_offset < 0:
+                if (i+${ife_offset} >= 0)
+                    ${name} = ${arg_name}[i+${ife_offset}];
+            %else:
+                ${name} = ${arg_name}[i];
+            %endif
+        %endfor
+
+        scan_type my_val = INPUT_EXPR(i);
+
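+        // After the updates below, "prev_item" holds the exclusive and
+        // "item" the inclusive scan result at index i.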
+        prev_item = item;
+        %if is_segmented:
+            {
+                bool is_seg_start = IS_SEG_START(i, my_val);
+                if (is_seg_start)
+                    prev_item = ${neutral};
+            }
+        %endif
+        item = SCAN_EXPR(prev_item, my_val);
+
+        {
+            OUTPUT_STMT(i, prev_item, item);
+        }
+    }
+}
+"""
+
+class GenericDebugScanKernel(_GenericScanKernelBase):
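+    """Performs the same function and has the same interface as
+    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan
+    executed by a single work item. Intended as a debugging aid.
+    """
+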
+    def finish_setup(self):
+        scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
+        scan_src = str(scan_tpl.render(
+            output_statement=_process_code_for_macro(self.output_statement),
+            argument_signature=_process_code_for_macro(self.arguments),
+            is_segment_start_expr=self.is_segment_start_expr,
+            input_expr=_process_code_for_macro(self.input_expr),
+            input_fetch_exprs=self.input_fetch_exprs,
+            wg_size=1,
+            **self.code_variables))
+
+        scan_prg = cl.Program(self.context, scan_src).build(self.options)
+        self.kernel = getattr(
+                scan_prg, self.name_prefix+"_debug_scan")
+        scalar_arg_dtypes = (
+                _get_scalar_arg_dtypes(self.parsed_args)
+                + [self.index_dtype])
+        self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
+
+    def __call__(self, *args, **kwargs):
+        # {{{ argument processing
+
+        allocator = kwargs.get("allocator")
+        queue = kwargs.get("queue")
+
+        if len(args) != len(self.parsed_args):
+            raise TypeError("invalid number of arguments in "
+                    "custom-arguments mode")
+
+        first_array = args[self.first_array_idx]
+        allocator = allocator or first_array.allocator
+        queue = queue or first_array.queue
+
+        n, = first_array.shape
+
+        data_args = []
+        from pyopencl.tools import VectorArg
+        for arg_descr, arg_val in zip(self.parsed_args, args):
+            if isinstance(arg_descr, VectorArg):
+                data_args.append(arg_val.data)
+            else:
+                data_args.append(arg_val)
+
+        # }}}
+
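+        # Launch a single work item; the debug kernel scans sequentially
+        # on the device.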
+        self.kernel(queue, (1,), (1,), *(data_args + [n]))
+
+# }}}
+
 # {{{ compatibility interface
 
-class _ScanKernelBase(GenericScanKernel):
+class _LegacyScanKernelBase(GenericScanKernel): # FIXME
     def __init__(self, ctx, dtype,
             scan_expr, neutral=None,
             name_prefix="scan", options=[], preamble="", devices=None):
@@ -1230,10 +1334,10 @@ class _ScanKernelBase(GenericScanKernel):
 
         return output_ary
 
-class InclusiveScanKernel(_ScanKernelBase):
+class InclusiveScanKernel(_LegacyScanKernelBase):
     ary_output_statement = "output_ary[i] = item;"
 
-class ExclusiveScanKernel(_ScanKernelBase):
+class ExclusiveScanKernel(_LegacyScanKernelBase):
     ary_output_statement = "output_ary[i] = prev_item;"
 
 # }}}
diff --git a/test/test_array.py b/test/test_array.py
index 44bdeb31..bef04fb9 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -913,22 +913,28 @@ def test_index_preservation(ctx_factory):
     context = ctx_factory()
     queue = cl.CommandQueue(context)
 
-    for n in scan_test_counts:
+    from pyopencl.scan import GenericScanKernel, GenericDebugScanKernel
+    classes = [GenericScanKernel]
 
-        from pyopencl.scan import GenericScanKernel
-        knl = GenericScanKernel(
-                context, np.int32,
-                arguments="__global int *out",
-                input_expr="i",
-                scan_expr="b", neutral="0",
-                output_statement="""
-                    out[i] = item;
-                    """)
-
-        out = cl_array.empty(queue, n, dtype=np.int32)
-        knl(out)
-
-        assert (out.get() == np.arange(n)).all()
+    dev = context.devices[0]
+    if dev.type == cl.device_type.CPU:
+        classes.append(GenericDebugScanKernel)
+
+    for cls in classes:
+        for n in scan_test_counts:
+            knl = cls(
+                    context, np.int32,
+                    arguments="__global int *out",
+                    input_expr="i",
+                    scan_expr="b", neutral="0",
+                    output_statement="""
+                        out[i] = item;
+                        """)
+
+            out = cl_array.empty(queue, n, dtype=np.int32)
+            knl(out)
+
+            assert (out.get() == np.arange(n)).all()
 
 @pytools.test.mark_test.opencl
 def test_segmented_scan(ctx_factory):
-- 
GitLab