diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 0db3a2187563e3df31d1eacb659711baebcedb10..35a23cfa3c2de6154188a491cd6a7849988703ee 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -138,12 +138,29 @@ void ${name_prefix}_scan_intervals(
     // index K in first dimension used for carry storage
     LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1];
 
+    // {{{ set up local data for input_fetch_exprs if any of them are stenciled
+
+    <%
+        fetch_expr_offsets = {}
+        for name, arg_name, ife_offset in input_fetch_exprs:
+            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+        local_fetch_expr_args = set(
+            arg_name
+            for arg_name, ife_offsets in fetch_expr_offsets.iteritems()
+            if -1 in ife_offsets or len(ife_offsets) > 1)
+    %>
+
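+    ## a stenciled arg_name (read at offset -1, or at more than one
+    ## offset) gets fetched into local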
+    %for arg_name in local_fetch_expr_args:
+        LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
+    %endfor
+
+    // }}}
+
     %if is_segmented:
         index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;
         LOCAL_MEM index_type l_first_segment_start_in_k_group[WG_SIZE];
         index_type first_segment_start_in_k_group;
-
-        if (LID_0 == 0)
     %endif
 
 
@@ -164,25 +181,67 @@ void ${name_prefix}_scan_intervals(
 
         {
 
+            // {{{ carry out input_fetch_exprs
+            // (if any of them need to be fetched into local)
+
+            %if local_fetch_expr_args:
+                for(index_type k = 0; k < K; k++)
+                {
+                    const index_type offset = k*WG_SIZE + LID_0;
+                    const index_type read_i = unit_base + offset;
+
+                    %for arg_name in local_fetch_expr_args:
+                        %if is_tail:
+                        if (read_i < interval_end)
+                        %endif
+                        {
+                            l_${arg_name}[offset] = ${arg_name}[read_i];
+                        }
+                    %endfor
+                }
+
+                local_barrier();
+            %endif
+
+            // }}}
+
             // {{{ read a unit's worth of data from global
 
             for(index_type k = 0; k < K; k++)
             {
                 const index_type offset = k*WG_SIZE + LID_0;
-
                 const index_type read_i = unit_base + offset;
 
                 %if is_tail:
                 if (read_i < interval_end)
                 %endif
                 {
+                    %for name, arg_name, ife_offset in input_fetch_exprs:
+                        ${arg_ctypes[arg_name]} ${name};
+
+                        %if arg_name in local_fetch_expr_args:
+                            if (offset + ${ife_offset} >= 0)
+                                ${name} = l_${arg_name}[offset + ${ife_offset}];
+                            else if (read_i + ${ife_offset} >= 0)
+                                ${name} = ${arg_name}[read_i + ${ife_offset}];
+                            // else: out of bounds, ${name} is left undefined
+
+                        %else:
+                            // ${arg_name} gets fetched directly from global
+                            ${name} = ${arg_name}[read_i];
+
+                        %endif
+                    %endfor
+
                     ldata[offset % K][offset / K] = INPUT_EXPR(read_i);
                 }
             }
 
             // }}}
 
-            // {{{ carry in from previous unit, if applicable.
+            // {{{ carry in from previous unit, if applicable
 
             %if is_segmented:
                 if (LID_0 == 0 && unit_base != interval_begin)
@@ -461,7 +520,8 @@ void ${name_prefix}_final_update(
 
 # {{{ lookbehind update
 
-# used for exclusive scan or output_statements that request look-behind
+# used for exclusive scan or output_statements that request look-behind, i.e.
+# access to the preceding item.
 
 LOOKBEHIND_UPDATE_SOURCE = SHARED_PREAMBLE + """//CL//
 
@@ -580,13 +640,17 @@ _PREFIX_WORDS = set("""
         segment_start_in_subtree offset interval_results interval_end
         first_segment_start_in_subtree unit_base
         first_segment_start_in_interval k INPUT_EXPR
-        prev_group_sum prev pv this add value n partial_val pgs OUTPUT_STMT
+        prev_group_sum prev pv value n partial_val pgs OUTPUT_STMT
         is_seg_start update_i scan_item_at_i seq_i read_i
+        l_
         """.split())
 
 _IGNORED_WORDS = set("""
         typedef for endfor if void while endwhile endfor endif else const printf
         None return bool
+
+        set iteritems len setdefault
+
         LID_2 LID_1 LID_0
         LDIM_0 LDIM_1 LDIM_2
         GDIM_0 GDIM_1 GDIM_2
@@ -612,11 +676,16 @@ _IGNORED_WORDS = set("""
         storage data 1 largest may representable uses entry Y meaningful
         computations interval At the left dimension know d
         A load B group perform shift tail see last OR
+        this add fetched into are directly need
+        gets them stenciled that undefined
+        there up any ones or name at more than one
 
         is_tail is_first_level input_expr argument_signature preamble
         double_support neutral output_statement index_type_max
         k_group_size name_prefix is_segmented index_ctype scan_ctype
-        wg_size is_i_segment_start_expr
+        wg_size is_i_segment_start_expr fetch_expr_offsets
+        arg_ctypes ife_offsets input_fetch_exprs
+        ife_offset arg_name local_fetch_expr_args
 
         a b prev_item i prev_item_unavailable_with_local_update prev_value
         N
@@ -643,7 +712,7 @@ def _make_template(s):
 
     if leftovers:
         from warnings import warn
-        warn("leftover words in identifier prefixing:" + " ".join(leftovers))
+        warn("leftover words in identifier prefixing: " + " ".join(leftovers))
 
     return mako.template.Template(s, strict_undefined=True, disable_unicode=True)
 
@@ -656,8 +725,9 @@ class _ScanKernelInfo(Record):
 class GenericScanKernel(object):
     def __init__(self, ctx, dtype,
             arguments, input_expr, scan_expr, neutral, output_statement,
-            is_i_segment_start_expr=None,
+            is_i_segment_start_expr=None, input_fetch_exprs=[],
             partial_scan_buffer_name=None,
+            index_dtype=np.int32,
             name_prefix="scan", options=[], preamble="", devices=None):
         """
         :arg ctx: a :class:`pyopencl.Context` within which the code
@@ -685,6 +755,13 @@ class GenericScanKernel(object):
             scan. Has access to the current index `i` and the input element
             as `a` and returns a bool. If it returns true, then previous
             sums will not spill over into the item with index i.
+        :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
+            Each entry has the effect of doing the equivalent of the
+            following before *input_expr* is evaluated::
+
+                typeof(ARG_NAME) NAME = ARG_NAME[i+OFFSET];
+
+            OFFSET is allowed to be 0 or -1.
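+
+            For instance, given a (hypothetical) argument named `src`::
+
+                input_fetch_exprs=[("src_im1", "src", -1), ("src_i", "src", 0)]
+
+            makes `src[i-1]` available as `src_im1` and `src[i]` as `src_i`
+            inside *input_expr*.
+        :arg index_dtype: the :class:`numpy.dtype` to use for indexing.
+            Must be signed; defaults to :class:`numpy.int32`.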
 
         The first array in the argument list determines the size of the index
         space over which the scan is carried out.
@@ -706,7 +783,9 @@ class GenericScanKernel(object):
                     "'output_statement' otherwise does something non-trivial",
                     stacklevel=2)
 
-        self.index_dtype = np.dtype(np.uint32)
+        self.index_dtype = np.dtype(index_dtype)
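+        # The generated code compares quantities such as "read_i + OFFSET"
+        # against zero, so negative index values must be representable.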
+        if np.iinfo(self.index_dtype).min >= 0:
+            raise TypeError("index_dtype must be signed")
 
         if devices is None:
             devices = ctx.devices
@@ -732,6 +811,22 @@ class GenericScanKernel(object):
             # can't overwrite any of the input.
             partial_scan_buffer_name = None
 
+            if input_fetch_exprs:
+                # FIXME need to insert code to handle input_fetch_exprs into
+                # the lookbehind update
+                raise NotImplementedError("input_fetch_exprs are not supported "
+                        "with a segmented scan using a look-behind update "
+                        "(e.g. an exclusive scan)")
+
+        for name, arg_name, ife_offset in input_fetch_exprs:
+            if ife_offset not in [0, -1]:
+                raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
+
+        arg_ctypes = {}
+        for arg in self.parsed_args:
+            arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
+
         if partial_scan_buffer_name is not None:
             self.partial_scan_buffer_idx, = [
                     i for i, arg in enumerate(self.parsed_args)
@@ -751,6 +846,7 @@ class GenericScanKernel(object):
             index_type_max=str(np.iinfo(self.index_dtype).max),
             scan_ctype=dtype_to_ctype(dtype),
             is_segmented=self.is_segmented,
+            arg_ctypes=arg_ctypes,
             scan_expr=scan_expr.replace("\n", " "),
             neutral=neutral.replace("\n", " "),
             double_support=all(
@@ -768,7 +864,9 @@ class GenericScanKernel(object):
         while True:
             candidate_scan_info = self.build_scan_kernel(
                     max_scan_wg_size, arguments, input_expr,
-                    is_i_segment_start_expr, is_first_level=True)
+                    is_i_segment_start_expr,
+                    input_fetch_exprs=input_fetch_exprs,
+                    is_first_level=True)
 
             # Will this device actually let us execute this kernel
             # at the desired work group size? Building it is the
@@ -795,7 +893,7 @@ class GenericScanKernel(object):
 
         # {{{ build second-level scan
 
-        second_level_arguments = [
+        second_level_arguments = self.arguments.split(",") + [
                 "__global %s *interval_sums" % dtype_to_ctype(dtype)]
         second_level_build_kwargs = {}
         if self.is_segmented:
@@ -816,6 +914,7 @@ class GenericScanKernel(object):
                 max_scan_wg_size,
                 arguments=", ".join(second_level_arguments),
                 input_expr="interval_sums[i]",
+                input_fetch_exprs=[],
                 is_first_level=False,
                 **second_level_build_kwargs)
 
@@ -856,7 +955,7 @@ class GenericScanKernel(object):
         # }}}
 
     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
-            is_i_segment_start_expr, is_first_level):
+            is_i_segment_start_expr, input_fetch_exprs, is_first_level):
         scalar_arg_dtypes = _get_scalar_arg_dtypes(_parse_args(arguments))
 
         # Thrust says that 128 is big enough for GT200
@@ -879,6 +978,7 @@ class GenericScanKernel(object):
             k_group_size=k_group_size,
             argument_signature=arguments.replace("\n", " "),
             is_i_segment_start_expr=is_i_segment_start_expr,
+            input_fetch_exprs=input_fetch_exprs,
             is_first_level=is_first_level,
             **self.code_variables))
 
@@ -972,10 +1072,10 @@ class GenericScanKernel(object):
         # can scan at most one interval
         assert interval_size >= num_intervals
 
-        scan2_args = (interval_results, interval_results)
+        scan2_args = data_args + [interval_results, interval_results]
         if self.is_segmented:
             scan2_args = scan2_args + [first_segment_start_in_interval]
-        scan2_args = scan2_args + (num_intervals, interval_size)
+        scan2_args = scan2_args + [num_intervals, interval_size]
 
         l2_info.kernel(
                 queue, (1,), (l1_info.wg_size,),
@@ -1053,7 +1153,8 @@ class ExclusiveScanKernel(_ScanKernelBase):
 # {{{ higher-level trickery
 
 @context_dependent_memoize
-def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
+def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype,
+        extra_args_types, preamble):
     ctype = dtype_to_ctype(dtype)
     arguments = [
         "__global %s *ary" % ctype,
@@ -1071,10 +1172,10 @@ def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
             output_statement="""
                 if (prev_item != item) out[item-1] = ary[i];
                 if (i+1 == N) *count = item;
-                """
-            )
+                """,
+            preamble=preamble)
 
-def copy_if(ary, predicate, extra_args=[], queue=None):
+def copy_if(ary, predicate, extra_args=[], queue=None, preamble=""):
     """
     :arg extra_args: a list of tuples *(name, value)* specifying extra
         arguments to pass to the scan procedure.
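+    :arg preamble: a snippet of C that is inserted into the generated
+        kernel ahead of the scan code, e.g. to define functions used by
+        *predicate*.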
@@ -1087,17 +1188,20 @@ def copy_if(ary, predicate, extra_args=[], queue=None):
     extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types)
+    knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype,
+            extra_args_types, preamble=preamble)
     out = cl_array.empty_like(ary)
     count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
     knl(ary, out, count, *extra_args_values, queue=queue)
     return out, count
 
-def remove_if(ary, predicate, extra_args=[], queue=None):
-    return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, queue=queue)
+def remove_if(ary, predicate, extra_args=[], queue=None, preamble=""):
+    return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, queue=queue,
+            preamble=preamble)
 
 @context_dependent_memoize
-def _get_partition_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
+def _get_partition_kernel(ctx, dtype, predicate, scan_dtype,
+        extra_args_types, preamble):
     ctype = dtype_to_ctype(dtype)
     arguments = [
         "__global %s *ary" % ctype,
@@ -1119,10 +1223,10 @@ def _get_partition_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
                 else
                     out_false[i-item] = ary[i];
                 if (i+1 == N) *count_true = item;
-                """
-            )
+                """,
+            preamble=preamble)
 
-def partition(ary, predicate, extra_args=[], queue=None):
+def partition(ary, predicate, extra_args=[], queue=None, preamble=""):
     """
     :arg extra_args: a list of tuples *(name, value)* specifying extra
         arguments to pass to the scan procedure.
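+    :arg preamble: a snippet of C that is inserted into the generated
+        kernel ahead of the scan code, e.g. to define functions used by
+        *predicate*.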
@@ -1135,7 +1239,8 @@ def partition(ary, predicate, extra_args=[], queue=None):
     extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
     extra_args_values = tuple(val for name, val in extra_args)
 
-    knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types)
+    knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype,
+            extra_args_types, preamble)
     out_true = cl_array.empty_like(ary)
     out_false = cl_array.empty_like(ary)
     count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
@@ -1143,35 +1248,53 @@ def partition(ary, predicate, extra_args=[], queue=None):
     return out_true, out_false, count
 
 @context_dependent_memoize
-def _get_unique_by_key_kernel(ctx, dtype, key_expr, scan_dtype, extra_args_types):
+def _get_unique_by_key_kernel(ctx, dtype, key_expr, scan_dtype,
+        extra_args_types, preamble):
     ctype = dtype_to_ctype(dtype)
     arguments = [
         "__global %s *ary" % ctype,
         "__global %s *out" % ctype,
-        "__global unsigned long *count_true",
+        "__global unsigned long *count_unique",
         ] + [
                 "%s %s" % (dtype_to_ctype(arg_dtype), name)
                 for name, arg_dtype in extra_args_types]
 
+    key_expr_define = "#define KEY_EXPR(a) %s\n" % key_expr.replace("\n", " ")
     return GenericScanKernel(
             ctx, dtype,
             arguments=", ".join(arguments),
-            input_expr="(%s) ? 1 : 0" % key_expr,
+            input_fetch_exprs=[
+                ("ary_im1", "ary", -1),
+                ("ary_i", "ary", 0),
+                ],
+            input_expr="(i == 0 || KEY_EXPR(ary_im1) != KEY_EXPR(ary_i)) ? 1 : 0",
             scan_expr="a+b", neutral="0",
             output_statement="""
-                if (prev_item != item)
-                    out_true[item-1] = ary[i];
-                else
-                    out_false[i-item] = ary[i];
-                if (i+1 == N) *count_true = item;
-                """)
+                if (prev_item != item) out[item-1] = ary[i];
+                if (i+1 == N) *count_unique = item;
+                """,
+            preamble=preamble+"\n\n"+key_expr_define)
 
-def unique_by_key(array, key_expr, **kwargs):
+def unique_by_key(ary, key_expr="a", extra_args=[], queue=None, preamble=""):
     """
     :arg extra_args: a list of tuples *(name, value)* specifying extra
         arguments to pass to the scan procedure.
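+    :arg key_expr: an expression in terms of a single array entry `a`
+        that computes the key by which uniqueness is judged. Defaults to
+        the entry itself.
+
+    Only runs of *adjacent* entries with equal keys are collapsed, so sort
+    the input first if globally unique entries are desired. A minimal
+    usage sketch::
+
+        a_unique_dev, count_dev = unique_by_key(a_dev)
+        a_unique = a_unique_dev.get()[:count_dev.get()]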
     """
 
+    if len(ary) > np.iinfo(np.uint32).max:
+        scan_dtype = np.uint64
+    else:
+        scan_dtype = np.uint32
+
+    extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
+    extra_args_values = tuple(val for name, val in extra_args)
+
+    knl = _get_unique_by_key_kernel(ary.context, ary.dtype, key_expr, scan_dtype,
+            extra_args_types, preamble)
+    out = cl_array.empty_like(ary)
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    knl(ary, out, count, *extra_args_values, queue=queue)
+    return out, count
 
 # }}}
 
diff --git a/test/test_array.py b/test/test_array.py
index 32221b69c5c53d8a9d768614e15c66eb4b549de3..a1cc402507974ebe13ba715541087da2fee36add 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -801,6 +801,27 @@ def test_partition(ctx_factory):
         assert (true_dev.get()[:count_true_dev] == true_host).all()
         assert (false_dev.get()[:n-count_true_dev] == false_host).all()
 
+@pytools.test.mark_test.opencl
+def test_unique(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.clrandom import rand as clrand
+    for n in scan_test_counts:
+        a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000)
+        a = a_dev.get()
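+        # unique_by_key only collapses adjacent duplicates, hence the sort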
+        a = np.sort(a)
+        a_dev = cl_array.to_device(queue, a)
+
+        a_unique_host = np.unique(a)
+
+        from pyopencl.scan import unique_by_key
+        a_unique_dev, count_unique_dev = unique_by_key(a_dev)
+
+        count_unique_dev = count_unique_dev.get()
+
+        assert (a_unique_dev.get()[:count_unique_dev] == a_unique_host).all()
+
 @pytools.test.mark_test.opencl
 def test_stride_preservation(ctx_factory):
     context = ctx_factory()