diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index d441dd49f57b9652c1b577db04ae03b6b1f9acd8..5ba79d6aa4fa810dca180f67362bf3da48505fe3 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -55,12 +55,12 @@ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
 
 ${preamble}
 
-typedef ${scan_ctype} scan_type;
-typedef ${index_ctype} index_type;
+typedef ${dtype_to_ctype(scan_dtype)} scan_type;
+typedef ${dtype_to_ctype(index_dtype)} index_type;
 
 // NO_SEG_BOUNDARY is the largest representable integer in index_type.
 // This assumption is used in code below.
-#define NO_SEG_BOUNDARY ${index_type_max}
+#define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
 """
 
 # }}}
@@ -484,18 +484,44 @@ void ${name_prefix}_scan_intervals(
 
             // {{{ write data
 
-            for (index_type k = 0; k < K; k++)
+            // work hard with index math to achieve contiguous 32-bit stores
             {
-                const index_type offset = k*WG_SIZE + LID_0;
+                __global int *dest = (__global int *) (partial_scan_buffer + unit_base);
 
-                %if is_tail:
-                if (unit_base + offset < interval_end)
-                %endif
-                {
-                    pycl_printf(("write: %d\n", unit_base + offset));
-                    partial_scan_buffer[unit_base + offset] =
-                        ldata[offset % K][offset / K];
-                }
+                <%
+
+                assert scan_dtype.itemsize % 4 == 0
+
+                ints_per_wg = wg_size
+                ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4
+
+                %>
+
+                const index_type scan_types_per_int = ${scan_dtype.itemsize//4};
+
+                %for store_base in xrange(0, ints_to_store, ints_per_wg):
+                    <%
+
+                    # Observe that ints_to_store is divisible by the work group size already,
+                    # so we won't go out of bounds that way.
+                    assert store_base + ints_per_wg <= ints_to_store
+
+                    %>
+
+                    %if is_tail:
+                    if (${store_base} + LID_0 < scan_types_per_int*(interval_end - unit_base))
+                    %endif
+                    {
+                        index_type linear_index = ${store_base} + LID_0;
+                        index_type linear_scan_data_idx = linear_index / scan_types_per_int;
+                        index_type remainder = linear_index - linear_scan_data_idx * scan_types_per_int;
+
+                        __local int *src = (__local int *) &(
+                            ldata[linear_scan_data_idx % K][linear_scan_data_idx / K]);
+
+                        dest[linear_index] = src[remainder];
+                    }
+                %endfor
             }
 
             pycl_printf(("after write\n"));
@@ -508,15 +534,15 @@ void ${name_prefix}_scan_intervals(
     % endfor
 
     // write interval sum
-    if (LID_0 == 0)
-    {
-        %if is_first_level:
-        interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
-        %endif
-        %if is_segmented and is_first_level:
-            g_first_segment_start_in_interval[GID_0] = first_segment_start_in_interval;
-        %endif
-    }
+    %if is_first_level:
+        if (LID_0 == 0)
+        {
+            interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
+            %if is_segmented:
+                g_first_segment_start_in_interval[GID_0] = first_segment_start_in_interval;
+            %endif
+        }
+    %endif
 }
 """
 
@@ -713,6 +739,8 @@ _PREFIX_WORDS = set("""
         l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
         first_seg_start_in_interval g_segment_start_flags
         group_base seg_end my_val DEBUG ARGS
+        ints_to_store ints_per_wg scan_types_per_int linear_index
+        linear_scan_data_idx dest src store_base
 
         LID_2 LID_1 LID_0
         LDIM_0 LDIM_1 LDIM_2
@@ -721,8 +749,11 @@ _PREFIX_WORDS = set("""
         """.split())
 
 _IGNORED_WORDS = set("""
+        4 32
+
         typedef for endfor if void while endwhile endfor endif else const printf
-        None return bool n char true false ifdef pycl_printf
+        None return bool n char true false ifdef pycl_printf str xrange assert
+        np iinfo max itemsize
 
         set iteritems len setdefault
 
@@ -754,17 +785,18 @@ _IGNORED_WORDS = set("""
         intra Therefore find code assumption
         branch workgroup complicated granularity phase remainder than simpler
         We smaller look ifs lots self behind allow barriers whole loop
-        after
+        after already Observe achieve contiguous stores hard go with by math
+        size won t way divisible bit so
 
         is_tail is_first_level input_expr argument_signature preamble
-        double_support neutral output_statement index_type_max
-        k_group_size name_prefix is_segmented index_ctype scan_ctype
+        double_support neutral output_statement
+        k_group_size name_prefix is_segmented index_dtype scan_dtype
         wg_size is_segment_start_expr fetch_expr_offsets
         arg_ctypes ife_offsets input_fetch_exprs def
         ife_offset arg_name local_fetch_expr_args update_body
         update_loop_lookbehind update_loop_plain update_loop
         use_lookbehind_update store_segment_start_flags
-        update_loop first_seg
+        update_loop first_seg scan_dtype dtype_to_ctype
 
         a b prev_item i last_item prev_value
         N NO_SEG_BOUNDARY across_seg_boundary
@@ -801,6 +833,9 @@ class _ScanKernelInfo(Record):
 
 # }}}
 
+class ScanPerformanceWarning(UserWarning):
+    pass
+
 class _GenericScanKernelBase(object):
     # {{{ constructor, argument processing
 
@@ -901,6 +936,9 @@ class _GenericScanKernelBase(object):
                     "'output_statement' otherwise does something non-trivial",
                     stacklevel=2)
 
+        if dtype.itemsize % 4 != 0:
+            raise TypeError("scan value type must have size divisible by 4 bytes")
+
         self.index_dtype = np.dtype(index_dtype)
         if np.iinfo(self.index_dtype).min >= 0:
             raise TypeError("index_dtype must be signed")
@@ -930,8 +968,10 @@ class _GenericScanKernelBase(object):
                 raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
         self.input_fetch_exprs = input_fetch_exprs
 
+        arg_dtypes = {}
         arg_ctypes = {}
         for arg in self.parsed_args:
+            arg_dtypes[arg.name] = arg.dtype
             arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
 
         self.options = options
@@ -943,12 +983,14 @@ class _GenericScanKernelBase(object):
         from pyopencl.characterize import has_double_support
 
         self.code_variables = dict(
+            np=np,
+            dtype_to_ctype=dtype_to_ctype,
             preamble=preamble,
             name_prefix=name_prefix,
-            index_ctype=dtype_to_ctype(self.index_dtype),
-            index_type_max=str(np.iinfo(self.index_dtype).max),
-            scan_ctype=dtype_to_ctype(dtype),
+            index_dtype=self.index_dtype,
+            scan_dtype=dtype,
             is_segmented=self.is_segmented,
+            arg_dtypes=arg_dtypes,
             arg_ctypes=arg_ctypes,
             scan_expr=_process_code_for_macro(scan_expr),
             neutral=_process_code_for_macro(neutral),
@@ -975,7 +1017,6 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         trip_count = 0
 
-
         if self.devices[0].type == cl.device_type.CPU:
             # (about the widest vector a CPU can support, also taking
             # into account that CPUs don't hide latency by large work groups
@@ -1002,6 +1043,19 @@ class GenericScanKernel(_GenericScanKernelBase):
                     wg_size, k_group_size) + 256  <= avail_local_mem):
                     solutions.append((wg_size*k_group_size, k_group_size, wg_size))
 
+        if self.devices[0].type == cl.device_type.GPU:
+            from pytools import any
+            for wg_size_floor in [256, 192, 128]:
+                have_sol_above_floor = any(wg_size >= wg_size_floor
+                        for _, _, wg_size in solutions)
+
+                if have_sol_above_floor:
+                    # delete all solutions not meeting the wg size floor
+                    solutions = [(total, k_group_size, wg_size)
+                            for total, k_group_size, wg_size in solutions
+                            if wg_size >= wg_size_floor]
+                    break
+
         _, k_group_size, max_scan_wg_size = max(solutions)
 
         while True:
@@ -1140,7 +1194,6 @@ class GenericScanKernel(_GenericScanKernelBase):
         wg_size = _round_down_to_power_of_2(
                 min(max_wg_size, 128))
 
-
         scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
         scan_src = str(scan_tpl.render(
             wg_size=wg_size,