diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 567d56eb10400559b6af080afb6edf79951d66ee..75e4a83e0dd904859b9d2633b9049dd5356598ef 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -41,7 +41,7 @@ from pyopencl._cluda import CLUDA_PREAMBLE
 
 
 
-SHARED_PREAMBLE = CLUDA_PREAMBLE + """
+SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
 #define WG_SIZE ${wg_size}
 
 /* SCAN_EXPR has no right know the indices it is scanning at because
@@ -49,6 +49,10 @@ each index may occur an undetermined number of times in the scan tree,
 and thus index-based side computations cannot be meaningful. */
 
 #define SCAN_EXPR(a, b) ${scan_expr}
+#define INPUT_EXPR(i) (${input_expr})
+%if is_segmented:
+    #define IS_SEG_START(i, a) (${is_i_segment_start_expr})
+%endif
 
 ${preamble}
 
@@ -66,11 +70,6 @@ typedef ${index_ctype} index_type;
 SCAN_INTERVALS_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL//
 
 #define K ${k_group_size}
-%if is_segmented:
-    #define IS_SEG_START(i, a) (${is_i_segment_start_expr})
-%endif
-#define INPUT_EXPR(i) (${input_expr})
-
 
 KERNEL
 REQD_WG_SIZE(WG_SIZE, 1, 1)
@@ -437,7 +436,8 @@ void ${name_prefix}_final_update(
 
         if(i < interval_end)
         {
-            scan_type value = SCAN_EXPR(prev_group_sum, *partial_scan_buffer);
+            scan_type val = partial_scan_buffer[i];
+            scan_type value = SCAN_EXPR(prev_group_sum, val);
             OUTPUT_STMT(i, value)
         }
     }
@@ -450,9 +450,7 @@ void ${name_prefix}_final_update(
 
 EXCLUSIVE_UPDATE_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL//
 
-        borked for now // FIXME
-
-#define OUTPUT_STMT(i, a) ${output_stmt}
+#define OUTPUT_STMT(i, a) ${output_statement}
 
 KERNEL
 REQD_WG_SIZE(WG_SIZE, 1, 1)
@@ -462,22 +460,22 @@ void ${name_prefix}_final_update(
     const index_type interval_size,
     GLOBAL_MEM scan_type *interval_results,
     GLOBAL_MEM scan_type *partial_scan_buffer
+    %if is_segmented:
+        , GLOBAL_MEM index_type *g_first_segment_start_in_interval
+    %endif
     )
 {
     LOCAL_MEM scan_type ldata[WG_SIZE];
 
     const index_type interval_begin = interval_size * GID_0;
-    const index_type interval_end   = min(interval_begin + interval_size, N);
+    const index_type interval_end = min(interval_begin + interval_size, N);
 
     // value to add to this segment
     scan_type carry = ${neutral};
     if(GID_0 != 0)
-    {
-        scan_type tmp = interval_results[GID_0 - 1];
-        carry = SCAN_EXPR(carry, tmp);
-    }
+        carry = interval_results[GID_0 - 1];
 
-    scan_type value = carry;
+    scan_type value = carry; // (A)
 
     for (index_type unit_base = interval_begin;
         unit_base < interval_end;
@@ -485,28 +483,39 @@ void ${name_prefix}_final_update(
     {
         const index_type i = unit_base + LID_0;
 
+        // load a work group's worth of data
         if (i < interval_end)
         {
-            scan_type tmp = interval_results[i];
+            scan_type tmp = partial_scan_buffer[i];
             ldata[LID_0] = SCAN_EXPR(carry, tmp);
         }
 
         local_barrier();
 
+        // perform right shift
         if (LID_0 != 0)
             value = ldata[LID_0 - 1];
         /*
-        else (see above)
-            value = carry OR last tail;
+        else 
+            value = carry (see (A)) OR last tail (see (B));
         */
 
+        %if is_segmented:
+        {
+            scan_type scan_item_at_i = INPUT_EXPR(i)
+            bool is_seg_start = IS_SEG_START(i, scan_item_at_i);
+            if (is_seg_start)
+                value = ${neutral};
+        }
+        %endif
+
         if (i < interval_end)
         {
             OUTPUT_STMT(i, value)
         }
 
         if(LID_0 == 0)
-            value = ldata[WG_SIZE - 1];
+            value = ldata[WG_SIZE - 1]; // (B)
 
         local_barrier();
     }
@@ -608,15 +617,21 @@ class _GenericScanKernelBase(object):
                 i for i, arg in enumerate(self.parsed_args)
                 if isinstance(arg, VectorArg)][0]
 
-        if partial_scan_buffer_name  is not None:
+        self.is_segmented = is_i_segment_start_expr is not None
+
+        if self.is_segmented and self.is_exclusive:
+            # The final update in segmented exclusive scan must be able to
+            # reconstruct where the segment boundaries were, and therefore
+            # can't overwrite any of the input.
+            partial_scan_buffer_name = None
+
+        if partial_scan_buffer_name is not None:
             self.partial_scan_buffer_idx, = [
                     i for i, arg in enumerate(self.parsed_args)
                     if arg.name == partial_scan_buffer_name]
         else:
             self.partial_scan_buffer_idx = None
 
-        self.is_segmented = is_i_segment_start_expr is not None
-
         # {{{ set up shared code dict
 
         from pytools import all
@@ -713,6 +728,8 @@ class _GenericScanKernelBase(object):
             wg_size=self.update_wg_size,
             output_statement=output_statement,
             argument_signature=arguments,
+            is_i_segment_start_expr=is_i_segment_start_expr,
+            input_expr=input_expr,
             **self.code_variables))
 
         final_update_prg = cl.Program(self.context, final_update_src).build(options)
@@ -805,8 +822,8 @@ class _GenericScanKernelBase(object):
         interval_size, num_intervals = uniform_interval_splitting(
                 n, unit_size, max_intervals)
 
-        print "n:%d interval_size: %d num_intervals: %d k_group_size:%d" % (
-                n, interval_size, num_intervals, l1_info.k_group_size)
+        #print "n:%d interval_size: %d num_intervals: %d k_group_size:%d" % (
+                #n, interval_size, num_intervals, l1_info.k_group_size)
 
         # {{{ first level scan of interval (one interval per block)
 
@@ -865,9 +882,11 @@ class _GenericScanKernelBase(object):
 
 class GenericInclusiveScanKernel(_GenericScanKernelBase):
     final_update_tp = INCLUSIVE_UPDATE_SOURCE
+    is_exclusive = False
 
 class GenericExclusiveScanKernel(_GenericScanKernelBase):
     final_update_tp = EXCLUSIVE_UPDATE_SOURCE
+    is_exclusive = True
 
 class _ScanKernelBase(_GenericScanKernelBase):
     def __init__(self, ctx, dtype,
@@ -914,8 +933,10 @@ class _ScanKernelBase(_GenericScanKernelBase):
 
 class InclusiveScanKernel(_ScanKernelBase):
     final_update_tp = INCLUSIVE_UPDATE_SOURCE
+    is_exclusive = False
 
 class ExclusiveScanKernel(_ScanKernelBase):
     final_update_tp = EXCLUSIVE_UPDATE_SOURCE
+    is_exclusive = True
 
 # vim: filetype=pyopencl:fdm=marker
diff --git a/test/test_array.py b/test/test_array.py
index 8e5a56ac4a56b872189b92239931f79027ef6e73..d3d2f52fc654e9c2d6c13271942fa39adc6cadc5 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -704,8 +704,8 @@ def summarize_error(obtained, desired, orig, thresh=1e-5):
                 entries.append("<%d ok>" % ok_count)
                 ok_count = 0
 
-            entries.append("%r (want: %r, diff: %r, orig: %r)" % (obtained[i], desired[i],
-                obtained[i]-desired[i], orig[i]))
+            entries.append("%r (want: %r, got: %r, orig: %r)" % (obtained[i], desired[i],
+                obtained[i], orig[i]))
         else:
             ok_count += 1
 
@@ -729,13 +729,15 @@ def test_scan(ctx_factory):
         knl = cls(context, dtype, "a+b", "0")
 
         for n in [
-            10, 2 ** 10 - 5, 2 ** 10,
-            2 ** 20 - 2 ** 18,
-            2 ** 20 - 2 ** 18 + 5,
-            2 ** 10 + 5,
-            2 ** 20 + 1,
-            2 ** 20, 2 ** 24
-            ]:
+                10,
+                2 ** 10 - 5,
+                2 ** 10,
+                2 ** 10 + 5,
+                2 ** 20 - 2 ** 18,
+                2 ** 20 - 2 ** 18 + 5,
+                2 ** 20 + 1,
+                2 ** 20, 2 ** 24
+                ]:
 
             host_data = np.random.randint(0, 10, n).astype(dtype)
             dev_data = cl_array.to_device(queue, host_data)
@@ -752,6 +754,7 @@ def test_scan(ctx_factory):
             if 0 and not is_ok:
                 print(summarize_error(dev_data.get(), desired_result, host_data))
 
+            print n, is_ok
             assert is_ok
             from gc import collect
             collect()