diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 7562f485cfb2809ad9dc1fc4cdbec45c269c97d8..0db3a2187563e3df31d1eacb659711baebcedb10 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -678,6 +678,9 @@ class GenericScanKernel(object):
             as *i*. *prev_item* is unavailable when using exclusive scan.
             *prev_item* in a segmented scan will be the neutral element
             at a segment boundary, not the immediately preceding item.
+
+            Note that *prev_item enables the construction of an exclusive
+            scan.
         :arg is_i_segment_start_expr: If given, makes the scan a segmented
             scan. Has access to the current index `i` and the input element
             as `a` and returns a bool. If it returns true, then previous
@@ -1050,11 +1053,19 @@ class ExclusiveScanKernel(_ScanKernelBase):
 # {{{ higher-level trickery
 
 @context_dependent_memoize
-def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype):
+def _get_copy_if_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
     ctype = dtype_to_ctype(dtype)
+    arguments = [
+        "__global %s *ary" % ctype,
+        "__global %s *out" % ctype,
+        "__global unsigned long *count",
+        ] + [
+                "%s %s" % (dtype_to_ctype(arg_dtype), name)
+                for name, arg_dtype in extra_args_types]
+
     return GenericScanKernel(
             ctx, dtype,
-            arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype),
+            arguments=", ".join(arguments),
             input_expr="(%s) ? 1 : 0" % predicate,
             scan_expr="a+b", neutral="0",
             output_statement="""
@@ -1063,26 +1074,104 @@ def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype):
                 """
             )
 
-def copy_if(ary, predicate, queue=None):
+def copy_if(ary, predicate, extra_args=[], queue=None):
+    """
+    :arg extra_args: a list of tuples *(name, value)* specifying extra
+        arguments to pass to the scan procedure.
+    """
     if len(ary) > np.iinfo(np.uint32):
         scan_dtype = np.uint64
     else:
         scan_dtype = np.uint32
 
-    knl = get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype)
+    extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
+    extra_args_values = tuple(val for name, val in extra_args)
+
+    knl = _get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types)
     out = cl_array.empty_like(ary)
     count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
-    knl(ary, out, count, queue=queue)
+    knl(ary, out, count, *extra_args_values, queue=queue)
     return out, count
 
-def remove_if(array, predicate, **kwargs):
-    pass
+def remove_if(ary, predicate, extra_args=[], queue=None):
+    return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, queue=queue)
 
-def partition(array, predicate):
-    pass
+@context_dependent_memoize
+def _get_partition_kernel(ctx, dtype, predicate, scan_dtype, extra_args_types):
+    ctype = dtype_to_ctype(dtype)
+    arguments = [
+        "__global %s *ary" % ctype,
+        "__global %s *out_true" % ctype,
+        "__global %s *out_false" % ctype,
+        "__global unsigned long *count_true",
+        ] + [
+                "%s %s" % (dtype_to_ctype(arg_dtype), name)
+                for name, arg_dtype in extra_args_types]
+
+    return GenericScanKernel(
+            ctx, dtype,
+            arguments=", ".join(arguments),
+            input_expr="(%s) ? 1 : 0" % predicate,
+            scan_expr="a+b", neutral="0",
+            output_statement="""
+                if (prev_item != item)
+                    out_true[item-1] = ary[i];
+                else
+                    out_false[i-item] = ary[i];
+                if (i+1 == N) *count_true = item;
+                """
+            )
+
+def partition(ary, predicate, extra_args=[], queue=None):
+    """
+    :arg extra_args: a list of tuples *(name, value)* specifying extra
+        arguments to pass to the scan procedure.
+    """
+    if len(ary) > np.iinfo(np.uint32):
+        scan_dtype = np.uint64
+    else:
+        scan_dtype = np.uint32
+
+    extra_args_types = tuple((name, val.dtype) for name, val in extra_args)
+    extra_args_values = tuple(val for name, val in extra_args)
+
+    knl = _get_partition_kernel(ary.context, ary.dtype, predicate, scan_dtype, extra_args_types)
+    out_true = cl_array.empty_like(ary)
+    out_false = cl_array.empty_like(ary)
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    knl(ary, out_true, out_false, count, *extra_args_values, queue=queue)
+    return out_true, out_false, count
+
+@context_dependent_memoize
+def _get_unique_by_key_kernel(ctx, dtype, key_expr, scan_dtype, extra_args_types):
+    ctype = dtype_to_ctype(dtype)
+    arguments = [
+        "__global %s *ary" % ctype,
+        "__global %s *out" % ctype,
+        "__global unsigned long *count_true",
+        ] + [
+                "%s %s" % (dtype_to_ctype(arg_dtype), name)
+                for name, arg_dtype in extra_args_types]
+
+    return GenericScanKernel(
+            ctx, dtype,
+            arguments=", ".join(arguments),
+            input_expr="(%s) ? 1 : 0" % key_expr,
+            scan_expr="a+b", neutral="0",
+            output_statement="""
+                if (prev_item != item)
+                    out_true[item-1] = ary[i];
+                else
+                    out_false[i-item] = ary[i];
+                if (i+1 == N) *count_true = item;
+                """)
+
+def unique_by_key(array, key_expr, **kwargs):
+    """
+    :arg extra_args: a list of tuples *(name, value)* specifying extra
+        arguments to pass to the scan procedure.
+    """
 
-def unique_by_key(array, key="", **kwargs):
-    pass
 
 # }}}
 
diff --git a/test/test_array.py b/test/test_array.py
index a540971811a4de9c92f97d950801da9202bf7f70..32221b69c5c53d8a9d768614e15c66eb4b549de3 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -773,11 +773,34 @@ def test_copy_if(ctx_factory):
 
         from pyopencl.scan import copy_if
 
-        selected = a[a>300]
-        selected_dev, count_dev = copy_if(a_dev, "ary[i] > 300")
+        crit = a_dev.dtype.type(300)
+        selected = a[a>crit]
+        selected_dev, count_dev = copy_if(a_dev, "ary[i] > myval", [("myval", crit)])
 
         assert (selected_dev.get()[:count_dev.get()] == selected).all()
 
+@pytools.test.mark_test.opencl
+def test_partition(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.clrandom import rand as clrand
+    for n in scan_test_counts:
+        a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000)
+        a = a_dev.get()
+
+        crit = a_dev.dtype.type(300)
+        true_host = a[a>crit]
+        false_host = a[a<=crit]
+
+        from pyopencl.scan import partition
+        true_dev, false_dev, count_true_dev = partition(a_dev, "ary[i] > myval", [("myval", crit)])
+
+        count_true_dev = count_true_dev.get()
+
+        assert (true_dev.get()[:count_true_dev] == true_host).all()
+        assert (false_dev.get()[:n-count_true_dev] == false_host).all()
+
 @pytools.test.mark_test.opencl
 def test_stride_preservation(ctx_factory):
     context = ctx_factory()