From 94eb9a38c60dd50d1016fe0a972dc35e60544db6 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sat, 10 Nov 2012 02:19:59 -0600
Subject: [PATCH] More scan performance tweaks.

---
 pyopencl/scan.py | 50 ++++++++++++++++++++++++++++++++----------------
 1 file changed, 33 insertions(+), 17 deletions(-)

diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index ae254998..6c144961 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -143,9 +143,25 @@ void ${name_prefix}_scan_intervals(
     %endif
     )
 {
-    // padded in WG_SIZE to avoid bank conflicts
     // index K in first dimension used for carry storage
-    LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1];
+    %if scan_dtype.itemsize > 4 and scan_dtype.itemsize % 8 == 0:
+        // Avoid bank conflicts by padding every element with one 32-bit int,
+        // so that the (packed) element size is no longer a multiple of 8 bytes.
+        struct __attribute__ ((__packed__)) wrapped_scan_type
+        {
+            scan_type value;
+            int dummy;
+        };
+        LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1];
+    %else:
+        struct wrapped_scan_type
+        {
+            scan_type value;
+        };
+
+        // NOTE(review): no WG_SIZE padding in this case (was WG_SIZE + 1) -- confirm intended
+        LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE];
+    %endif
 
     %if is_segmented:
         LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
@@ -253,7 +269,7 @@ void ${name_prefix}_scan_intervals(
 
                     const index_type o_mod_k = offset % K;
                     const index_type o_div_k = offset / K;
-                    ldata[o_mod_k][offset / K] = scan_value;
+                    ldata[o_mod_k][offset / K].value = scan_value;
 
                     %if is_segmented:
                         bool is_seg_start = IS_SEG_START(read_i, scan_value);
@@ -281,7 +297,7 @@ void ${name_prefix}_scan_intervals(
 
             if (LID_0 == 0 && unit_base != interval_begin)
             {
-                ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0],
+                ldata[0][0].value = SCAN_EXPR(ldata[K][WG_SIZE - 1].value, ldata[0][0].value,
                     %if is_segmented:
                         (l_segment_start_flags[0][0])
                     %else:
@@ -298,7 +314,7 @@ void ${name_prefix}_scan_intervals(
 
             // {{{ scan along k (sequentially in each work item)
 
-            scan_type sum = ldata[0][LID_0];
+            scan_type sum = ldata[0][LID_0].value;
 
             %if is_tail:
                 const index_type offset_end = interval_end - unit_base;
@@ -310,7 +326,7 @@ void ${name_prefix}_scan_intervals(
                 if (K * LID_0 + k < offset_end)
                 %endif
                 {
-                    scan_type tmp = ldata[k][LID_0];
+                    scan_type tmp = ldata[k][LID_0].value;
                     index_type seq_i = unit_base + K*LID_0 + k;
 
                     %if is_segmented:
@@ -330,7 +346,7 @@ void ${name_prefix}_scan_intervals(
                         %endif
                         );
 
-                    ldata[k][LID_0] = sum;
+                    ldata[k][LID_0].value = sum;
                 }
             }
 
@@ -339,7 +355,7 @@ void ${name_prefix}_scan_intervals(
             // }}}
 
             // store carry in out-of-bounds (padding) array entry (index K) in the K direction
-            ldata[K][LID_0] = sum;
+            ldata[K][LID_0].value = sum;
 
             %if is_segmented:
                 l_first_segment_start_in_subtree[LID_0] = first_segment_start_in_k_group;
@@ -361,7 +377,7 @@ void ${name_prefix}_scan_intervals(
             // across k groups, along local id
             // (uses out-of-bounds k=K array entry for storage)
 
-            scan_type val = ldata[K][LID_0];
+            scan_type val = ldata[K][LID_0].value;
 
             <% scan_offset = 1 %>
 
@@ -370,7 +386,7 @@ void ${name_prefix}_scan_intervals(
 
                 if (LID_0 >= ${scan_offset})
                 {
-                    scan_type tmp = ldata[K][LID_0 - ${scan_offset}];
+                    scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value;
                     % if is_tail:
                     if (K*LID_0 < offset_end)
                     % endif
@@ -409,7 +425,7 @@ void ${name_prefix}_scan_intervals(
 
                 // {{{ writes to local allowed, reads from local not allowed
 
-                ldata[K][LID_0] = val;
+                ldata[K][LID_0].value = val;
                 %if is_segmented:
                     l_first_segment_start_in_subtree[LID_0] =
                         first_segment_start_in_subtree;
@@ -445,7 +461,7 @@ void ${name_prefix}_scan_intervals(
 
             if (LID_0 > 0)
             {
-                sum = ldata[K][LID_0 - 1];
+                sum = ldata[K][LID_0 - 1].value;
 
                 for(index_type k = 0; k < K; k++)
                 {
@@ -453,8 +469,8 @@ void ${name_prefix}_scan_intervals(
                     if (K * LID_0 + k < offset_end)
                     %endif
                     {
-                        scan_type tmp = ldata[k][LID_0];
-                        ldata[k][LID_0] = SCAN_EXPR(sum, tmp,
+                        scan_type tmp = ldata[k][LID_0].value;
+                        ldata[k][LID_0].value = SCAN_EXPR(sum, tmp,
                             %if is_segmented:
                                 (unit_base + K * LID_0 + k
                                     >= first_segment_start_in_k_group)
@@ -517,7 +533,7 @@ void ${name_prefix}_scan_intervals(
                         index_type remainder = linear_index - linear_scan_data_idx * scan_types_per_int;
 
                         __local int *src = (__local int *) &(
-                            ldata[linear_scan_data_idx % K][linear_scan_data_idx / K]);
+                            ldata[linear_scan_data_idx % K][linear_scan_data_idx / K].value);
 
                         dest[linear_index] = src[remainder];
                     }
@@ -1190,9 +1206,9 @@ class GenericScanKernel(_GenericScanKernelBase):
             store_segment_start_flags, k_group_size):
         scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments)
 
-        # Thrust says that 128 is big enough for GT200
+        # Empirically found on Nvidia hardware: the work group size need not exceed this
         wg_size = _round_down_to_power_of_2(
-                min(max_wg_size, 128))
+                min(max_wg_size, 256))
 
         scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
         scan_src = str(scan_tpl.render(
-- 
GitLab