From 539c1f990870091efd51597e9cba2c92e2a72a3f Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 23 Jul 2012 13:29:37 -0500
Subject: [PATCH] Segmented scan passes all tests.

---
 pyopencl/scan.py   | 134 +++++++++++++++++++++++----------------------
 test/test_array.py |  16 +++---
 2 files changed, 79 insertions(+), 71 deletions(-)

diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 0aaabb91..136c789c 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -45,6 +45,8 @@ from pyopencl.tools import context_dependent_memoize
 # {{{ preamble
 
 SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
+#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
+
 #define WG_SIZE ${wg_size}
 
 /* SCAN_EXPR has no right know the indices it is scanning at because
@@ -135,7 +137,7 @@ void ${name_prefix}_scan_intervals(
         , GLOBAL_MEM index_type *g_first_segment_start_in_interval
     %endif
     %if store_segment_start_flags:
-        , GLOBAL_MEM index_type *g_segment_start_flags
+        , GLOBAL_MEM char *g_segment_start_flags
     %endif
     )
 {
@@ -144,7 +146,7 @@ void ${name_prefix}_scan_intervals(
     LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1];
 
     %if is_segmented:
-        LOCAL_MEM bool l_segment_start_flags[K][WG_SIZE];
+        LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
         LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];
 
         // only relevant/populated for local id 0
@@ -513,7 +515,7 @@ void ${name_prefix}_final_update(
         , GLOBAL_MEM index_type *g_first_segment_start_in_interval
     %endif
     %if is_segmented and use_lookbehind_update:
-        , GLOBAL_MEM bool *g_segment_start_flags
+        , GLOBAL_MEM char *g_segment_start_flags
     %endif
     )
 {
@@ -521,7 +523,7 @@ void ${name_prefix}_final_update(
         LOCAL_MEM scan_type ldata[WG_SIZE];
     %endif
     %if is_segmented and use_lookbehind_update:
-        LOCAL_MEM scan_type l_segment_start_flags[WG_SIZE];
+        LOCAL_MEM char l_segment_start_flags[WG_SIZE];
     %endif
 
     const index_type interval_begin = interval_size * GID_0;
@@ -532,12 +534,18 @@ void ${name_prefix}_final_update(
     if (GID_0 != 0)
         carry = interval_results[GID_0 - 1];
 
+    %if is_segmented:
+        const index_type first_seg_start_in_interval =
+            g_first_segment_start_in_interval[GID_0];
+
+    %endif
+
     %if not use_lookbehind_update:
         // {{{ no look-behind ('prev_item' not in output_statement -> simpler)
 
         index_type update_i = interval_begin+LID_0;
 
-        <%def name="update_loop_plain(end, phase)">
+        <%def name="update_loop_plain(end)">
             for(; update_i < ${end}; update_i += WG_SIZE)
             {
                 scan_type partial_val = partial_scan_buffer[update_i];
@@ -547,7 +555,17 @@ void ${name_prefix}_final_update(
             }
         </%def>
 
-        <% update_loop = self.update_loop_plain %>
+        // {{{ update to the first intra-interval segment boundary
+
+        %if is_segmented:
+            index_type seg_end = min(first_seg_start_in_interval, interval_end);
+            ${update_loop_plain("seg_end")}
+            carry = ${neutral};
+        %endif
+
+        // }}}
+
+        ${update_loop_plain("interval_end")}
 
         // }}}
     %else:
@@ -560,71 +578,56 @@ void ${name_prefix}_final_update(
         index_type group_base = interval_begin;
         scan_type prev_value = carry; // (A)
 
-        <%def name="update_loop_lookbehind(end, phase)">
-            for(; group_base < ${end}; group_base += WG_SIZE)
-            {
-                index_type update_i = group_base+LID_0;
-
-                // load a work group's worth of data
-                if (update_i < ${end})
-                {
-                    scan_type tmp = partial_scan_buffer[update_i];
-                    ldata[LID_0] = SCAN_EXPR(carry, tmp);
-                    %if is_segmented:
-                        l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
-                    %endif
-                }
-
-                local_barrier();
+        for(; group_base < interval_end; group_base += WG_SIZE)
+        {
+            index_type update_i = group_base+LID_0;
 
-                // find prev_value
-                if (LID_0 != 0)
-                    prev_value = ldata[LID_0 - 1];
-                /*
-                else
-                    prev_value = carry (see (A)) OR last tail (see (B));
-                */
+            // load a work group's worth of data
+            if (update_i < interval_end)
+            {
+                scan_type tmp = partial_scan_buffer[update_i];
 
-                if (update_i < ${end})
-                {
-                    %if is_segmented:
-                        if (l_segment_start_flags[LID_0])
-                            prev_value = ${neutral};
-                    %endif
+                %if is_segmented:
+                if (update_i < first_seg_start_in_interval)
+                %endif
+                { tmp = SCAN_EXPR(carry, tmp); }
 
-                    scan_type value = ldata[LID_0];
-                    OUTPUT_STMT(update_i, prev_value, value)
-                }
+                ldata[LID_0] = tmp;
 
-                if (LID_0 == 0)
-                    prev_value = ldata[WG_SIZE - 1]; // (B)
-
-                local_barrier();
+                %if is_segmented:
+                    l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
+                %endif
             }
-        </%def>
 
-        // FIXME TAKE CARE OF BLOCK RESYNCING
+            local_barrier();
 
-        <% update_loop = self.update_loop_lookbehind %>
+            // find prev_value
+            if (LID_0 != 0)
+                prev_value = ldata[LID_0 - 1];
+            /*
+            else
+                prev_value = carry (see (A)) OR last tail (see (B));
+            */
 
-        // }}}
-    %endif
-
-    %if is_segmented:
-        // {{{ update to the first intra-interval segment boundary
+            if (update_i < interval_end)
+            {
+                %if is_segmented:
+                    if (l_segment_start_flags[LID_0])
+                        prev_value = ${neutral};
+                %endif
 
-        const index_type first_seg_start_in_interval =
-            g_first_segment_start_in_interval[GID_0];
+                scan_type value = ldata[LID_0];
+                OUTPUT_STMT(update_i, prev_value, value)
+            }
 
-        index_type seg_end = min(first_seg_start_in_interval, interval_end);
-        ${update_loop("seg_end", phase="first_seg")}
+            if (LID_0 == 0)
+                prev_value = ldata[WG_SIZE - 1]; // (B)
 
-        carry = ${neutral};
+            local_barrier();
+        }
 
         // }}}
     %endif
-
-    ${update_loop("interval_end", phase="remainder")}
 }
 """
 
@@ -674,24 +677,26 @@ _PREFIX_WORDS = set("""
         l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
         first_seg_start_in_interval g_segment_start_flags
         group_base seg_end
+
+        LID_2 LID_1 LID_0
+        LDIM_0 LDIM_1 LDIM_2
+        GDIM_0 GDIM_1 GDIM_2
+        GID_0 GID_1 GID_2
         """.split())
 
 _IGNORED_WORDS = set("""
         typedef for endfor if void while endwhile endfor endif else const printf
-        None return bool n
+        None return bool n char
 
         set iteritems len setdefault
 
-        LID_2 LID_1 LID_0
-        LDIM_0 LDIM_1 LDIM_2
-        GDIM_0 GDIM_1 GDIM_2
-        GID_0 GID_1 GID_2
         GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
         local_barrier
         CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
         pragma __attribute__ __global __kernel __local
         get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
         get_num_groups barrier get_group_id
+        cl_khr_byte_addressable_store
 
         _final_update _scan_intervals
 
@@ -961,7 +966,8 @@ class GenericScanKernel(object):
             use_lookbehind_update=use_lookbehind_update,
             **self.code_variables))
 
-        with open("update.cl", "wt") as f: f.write(final_update_src)
+        #with open("update.cl", "wt") as f: f.write(final_update_src)
+
         final_update_prg = cl.Program(self.context, final_update_src).build(options)
         self.final_update_knl = getattr(
                 final_update_prg,
@@ -1011,7 +1017,7 @@ class GenericScanKernel(object):
             store_segment_start_flags=store_segment_start_flags,
             **self.code_variables))
 
-        with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src)
+        #with open("scan-lev%d.cl" % (1 if is_first_level else 2), "wt") as f: f.write(scan_src)
 
         prg = cl.Program(self.context, scan_src).build(self.options)
 
diff --git a/test/test_array.py b/test/test_array.py
index 247ac996..18768689 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -807,7 +807,7 @@ def test_scan(ctx_factory):
             if 1 and not is_ok:
                 print(summarize_error(dev_data.get(), desired_result, host_data))
 
-            print n, is_ok
+            print(n, is_ok)
             assert is_ok
             from gc import collect
             collect()
@@ -882,8 +882,8 @@ def test_segmented_scan(ctx_factory):
     dtype = np.int32
     ctype = dtype_to_ctype(dtype)
 
-    for is_exclusive in [False, True]:
-    #for is_exclusive in [True, False]:
+    #for is_exclusive in [False, True]:
+    for is_exclusive in [True, False]:
         if is_exclusive:
             output_statement = "out[i] = prev_item"
         else:
@@ -924,6 +924,7 @@ def test_segmented_scan(ctx_factory):
 
             for seg_boundaries in seg_boundaries_values:
                 #print "BOUNDARIES", seg_boundaries
+                #print a
 
                 seg_boundary_flags = np.zeros(n, dtype=np.uint8)
                 seg_boundary_flags[seg_boundaries] = 1
@@ -940,11 +941,11 @@ def test_segmented_scan(ctx_factory):
 
                     if is_exclusive:
                         result_host[seg_start+1:seg_end] = np.cumsum(
-                                result_host[seg_start:seg_end][:-1])
+                                a[seg_start:seg_end][:-1])
                         result_host[seg_start] = 0
                     else:
                         result_host[seg_start:seg_end] = np.cumsum(
-                                result_host[seg_start:seg_end])
+                                a[seg_start:seg_end])
 
                 #print "REF", result_host
 
@@ -955,12 +956,13 @@ def test_segmented_scan(ctx_factory):
                 is_correct = (result_dev.get() == result_host).all()
                 if not is_correct:
                     diff = result_dev.get() - result_host
-                    print "RES-REF", diff
+                    print("RES-REF", diff)
                     print("ERRWHERE", np.where(diff))
                     print(n, list(seg_boundaries))
 
                 assert is_correct
-            print(n, "done")
+
+            print("%d excl:%s done" % (n, is_exclusive))
 
 
 
-- 
GitLab