From 58356197ff0f2ef9d288357b60c8a7387512fed0 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 29 Jul 2012 20:40:17 -0400
Subject: [PATCH] Shift responsibility for segmentation to scan_expr.

---
 pyopencl/scan.py   | 139 ++++++++++++++++++++++++++-------------------
 test/test_array.py |   5 +-
 2 files changed, 84 insertions(+), 60 deletions(-)

diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index d4744e5c..6565682f 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -47,11 +47,7 @@ from pyopencl.tools import context_dependent_memoize
 SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
 #define WG_SIZE ${wg_size}
 
-/* SCAN_EXPR has no right know the indices it is scanning at because
-each index may occur an undetermined number of times in the scan tree,
-and thus index-based side computations cannot be meaningful. */
-
-#define SCAN_EXPR(a, b) ${scan_expr}
+#define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr}
 #define INPUT_EXPR(i) (${input_expr})
 %if is_segmented:
     #define IS_SEG_START(i, a) (${is_segment_start_expr})
@@ -272,13 +268,15 @@ void ${name_prefix}_scan_intervals(
                     first_segment_start_in_k_group = unit_base + K*LID_0;
             %endif
 
-            if (LID_0 == 0 && unit_base != interval_begin
-                %if is_segmented:
-                    && !l_segment_start_flags[0][0]
-                %endif
-                )
+            if (LID_0 == 0 && unit_base != interval_begin)
             {
-                ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]);
+                ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0],
+                    %if is_segmented:
+                        (l_segment_start_flags[0][0])
+                    %else:
+                        false
+                    %endif
+                    );
             }
 
             // }}}
@@ -308,11 +306,16 @@ void ${name_prefix}_scan_intervals(
                         first_segment_start_in_k_group = min(
                             first_segment_start_in_k_group,
                             seq_i);
-                        sum = tmp;
                     }
-                    else
                     %endif
-                        sum = SCAN_EXPR(sum, tmp);
+
+                    sum = SCAN_EXPR(sum, tmp,
+                        %if is_segmented:
+                            (l_segment_start_flags[k][LID_0])
+                        %else:
+                            false
+                        %endif
+                        );
 
                     ldata[k][LID_0] = sum;
                 }
@@ -357,12 +360,13 @@ void ${name_prefix}_scan_intervals(
                     if (K*LID_0 < offset_end)
                     % endif
                     {
-                        %if is_segmented:
-                            if (l_first_segment_start_in_subtree[LID_0] == NO_SEG_BOUNDARY)
-                                val = SCAN_EXPR(tmp, val);
-                        %else:
-                            val = SCAN_EXPR(tmp, val);
-                        %endif
+                        val = SCAN_EXPR(tmp, val,
+                            %if is_segmented:
+                                (l_first_segment_start_in_subtree[LID_0] != NO_SEG_BOUNDARY)
+                            %else:
+                                false
+                            %endif
+                            );
                     }
 
                     %if is_segmented:
@@ -428,19 +432,19 @@ void ${name_prefix}_scan_intervals(
 
                 for(index_type k = 0; k < K; k++)
                 {
-                    bool do_update = true;
                     %if is_tail:
-                        do_update = K * LID_0 + k < offset_end;
+                    if (K * LID_0 + k < offset_end)
                     %endif
-                    %if is_segmented:
-                        do_update = unit_base + K * LID_0 + k
-                            < first_segment_start_in_k_group;
-                    %endif
-
-                    if (do_update)
                     {
                         scan_type tmp = ldata[k][LID_0];
-                        ldata[k][LID_0] = SCAN_EXPR(sum, tmp);
+                        ldata[k][LID_0] = SCAN_EXPR(sum, tmp,
+                            %if is_segmented:
+                                (unit_base + K * LID_0 + k
+                                    >= first_segment_start_in_k_group)
+                            %else:
+                                false
+                            %endif
+                            );
                     }
                 }
             }
@@ -542,28 +546,24 @@ void ${name_prefix}_final_update(
 
         index_type update_i = interval_begin+LID_0;
 
-        <%def name="update_loop_plain(end)">
-            for(; update_i < ${end}; update_i += WG_SIZE)
-            {
-                scan_type partial_val = partial_scan_buffer[update_i];
-                scan_type item = SCAN_EXPR(carry, partial_val);
-                index_type i = update_i;
-
-                { ${output_statement}; }
-            }
-        </%def>
-
-        // {{{ update to the first intra-interval segment boundary
-
         %if is_segmented:
             index_type seg_end = min(first_seg_start_in_interval, interval_end);
-            ${update_loop_plain("seg_end")}
-            carry = ${neutral};
         %endif
 
-        // }}}
+        for(; update_i < interval_end; update_i += WG_SIZE)
+        {
+            scan_type partial_val = partial_scan_buffer[update_i];
+            scan_type item = SCAN_EXPR(carry, partial_val,
+                %if is_segmented:
+                    (update_i >= seg_end)
+                %else:
+                    false
+                %endif
+                );
+            index_type i = update_i;
 
-        ${update_loop_plain("interval_end")}
+            { ${output_statement}; }
+        }
 
         // }}}
     %else:
@@ -585,10 +585,13 @@ void ${name_prefix}_final_update(
             {
                 scan_type tmp = partial_scan_buffer[update_i];
 
-                %if is_segmented:
-                if (update_i < first_seg_start_in_interval)
-                %endif
-                { tmp = SCAN_EXPR(carry, tmp); }
+                tmp = SCAN_EXPR(carry, tmp,
+                    %if is_segmented:
+                        (update_i >= first_seg_start_in_interval)
+                    %else:
+                        false
+                    %endif
+                    );
 
                 ldata[LID_0] = tmp;
 
@@ -703,7 +706,7 @@ _PREFIX_WORDS = set("""
 
 _IGNORED_WORDS = set("""
         typedef for endfor if void while endwhile endfor endif else const printf
-        None return bool n char
+        None return bool n char true false
 
         set iteritems len setdefault
 
@@ -716,7 +719,7 @@ _IGNORED_WORDS = set("""
 
         _final_update _scan_intervals _debug_scan
 
-        positions all padded integer its previous write based true writes 0
+        positions all padded integer its previous write based writes 0
         has local worth scan_expr to read cannot not X items False bank
         four beginning follows applicable item min each indices works side
         scanning right summed relative used id out index avoid current state
@@ -747,7 +750,7 @@ _IGNORED_WORDS = set("""
         update_loop first_seg
 
         a b prev_item i prev_item_unavailable_with_local_update prev_value
-        N NO_SEG_BOUNDARY
+        N NO_SEG_BOUNDARY across_seg_boundary
         """.split())
 
 def _make_template(s):
@@ -807,6 +810,21 @@ class _GenericScanKernelBase(object):
             modified by the scan, it should live in `b`.
 
             This expression may call functions given in the *preamble*.
+
+            Another value available to this expression is `across_seg_boundary`,
+            a C `bool` indicating whether this scan update crosses a
+            segment boundary, as defined by `is_segment_start_expr`.
+            The scan routine does not implement segmentation semantics on
+            its own; `scan_expr` must do so, e.g. for a segmented sum:
+            `across_seg_boundary ? b : (a+b)`. This value is available
+            (but always `false`) even for a non-segmented scan.
+
+            .. note::
+
+                In early pre-releases of the segmented scan,
+                segmentation semantics were implemented *without*
+                relying on `scan_expr`.
+
         :arg input_expr: A C expression, encoded as a string, resulting
             in the values to which the scan is applied. This may be used
             to apply a mapping to values stored in *arguments* before being
@@ -1230,13 +1248,16 @@ void ${name_prefix}_debug_scan(
 
         prev_item = item;
         %if is_segmented:
-            {
-                bool is_seg_start = IS_SEG_START(i, my_val);
-                if (is_seg_start)
-                    prev_item = ${neutral};
-            }
+            bool is_seg_start = IS_SEG_START(i, my_val);
         %endif
-        item = SCAN_EXPR(prev_item, my_val);
+
+        item = SCAN_EXPR(prev_item, my_val,
+            %if is_segmented:
+                is_seg_start
+            %else:
+                false
+            %endif
+            );
 
         {
             ${output_statement};
diff --git a/test/test_array.py b/test/test_array.py
index bef04fb9..694f5a51 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -957,7 +957,7 @@ def test_segmented_scan(ctx_factory):
                 arguments="__global %s *ary, __global char *segflags, __global %s *out"
                     % (ctype, ctype),
                 input_expr="ary[i]",
-                scan_expr="a+b", neutral="0",
+                scan_expr="across_seg_boundary ? b : (a+b)", neutral="0",
                 is_segment_start_expr="segflags[i]",
                 output_statement=output_statement,
                 options=[])
@@ -1117,6 +1117,9 @@ def test_view(ctx_factory):
 
 @pytools.test.mark_test.opencl
 def no_test_slice(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
     from pyopencl.clrandom import rand as clrand
 
     l = 20000
-- 
GitLab