diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 75e4a83e0dd904859b9d2633b9049dd5356598ef..7562f485cfb2809ad9dc1fc4cdbec45c269c97d8 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -36,11 +36,14 @@ import pyopencl.array as cl_array
 from pyopencl.tools import dtype_to_ctype, bitlog2
 import pyopencl._mymako as mako
 from pyopencl._cluda import CLUDA_PREAMBLE
+from pyopencl.tools import context_dependent_memoize
 
 
 
 
 
+# {{{ preamble
+
 SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
 #define WG_SIZE ${wg_size}
 
@@ -56,18 +59,59 @@ and thus index-based side computations cannot be meaningful. */
 
 ${preamble}
 
-typedef ${scan_type} scan_type;
+typedef ${scan_ctype} scan_type;
 typedef ${index_ctype} index_type;
 #define NO_SEG_BOUNDARY ${index_type_max}
-
 """
 
-
-
+# }}}
 
 # {{{ main scan code
 
-SCAN_INTERVALS_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL//
+# Algorithm: Each work group is responsible for one contiguous
+# 'interval'. There are just enough intervals to fill all compute
+# units.  Intervals are split into 'units'. A unit is what gets
+# worked on in parallel by one work group.
+#
+# in index space:
+# interval > unit > local-parallel > k-group
+#
+# (Note that there is also a transpose in here: The data is read
+# with local ids along linear index order.)
+#
+# Each unit has two axes--the local-id axis and the k axis.
+#
+# unit 0:
+# | | | | | | | | | | ----> lid
+# | | | | | | | | | |
+# | | | | | | | | | |
+# | | | | | | | | | |
+# | | | | | | | | | |
+#
+# |
+# v k (fastest-moving in linear index)
+#
+# unit 1:
+# | | | | | | | | | | ----> lid
+# | | | | | | | | | |
+# | | | | | | | | | |
+# | | | | | | | | | |
+# | | | | | | | | | |
+#
+# |
+# v k (fastest-moving in linear index)
+#
+# ...
+#
+# At a device-global level, this is a three-phase algorithm, in
+# which first each interval does its local scan, then a scan
+# across intervals exchanges data globally, and the final update
+# adds the exchanged sums to each interval.
+#
+# Exclusive scan is realized by performing a right-shift inside
+# the final update.
+
+SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + """//CL//
 
 #define K ${k_group_size}
 
@@ -119,48 +163,6 @@ void ${name_prefix}_scan_intervals(
         %endif
 
         {
-            // Algorithm: Each work group is responsible for one contiguous
-            // 'interval'. There are just enough intervals to fill all compute
-            // units.  Intervals are split into 'units'. A unit is what gets
-            // worked on in parallel by one work group.
-
-            // in index space:
-            // interval > unit > local-parallel > k-group
-
-            // (Note that there is also a transpose in here: The data is read
-            // with local ids along linear index order.)
-
-            // Each unit has two axes--the local-id axis and the k axis.
-            //
-            // unit 0:
-            // | | | | | | | | | | ----> lid
-            // | | | | | | | | | |
-            // | | | | | | | | | |
-            // | | | | | | | | | |
-            // | | | | | | | | | |
-            //
-            // |
-            // v k (fastest-moving in linear index)
-
-            // unit 1:
-            // | | | | | | | | | | ----> lid
-            // | | | | | | | | | |
-            // | | | | | | | | | |
-            // | | | | | | | | | |
-            // | | | | | | | | | |
-            //
-            // |
-            // v k (fastest-moving in linear index)
-            //
-            // ...
-
-            // At a device-global level, this is a three-phase algorithm, in
-            // which first each interval does its local scan, then a scan
-            // across intervals exchanges data globally, and the final update
-            // adds the exchanged sums to each interval.
-
-            // Exclusive scan is realized by performing a right-shift inside
-            // the final update.
 
             // {{{ read a unit's worth of data from global
 
@@ -168,13 +170,13 @@ void ${name_prefix}_scan_intervals(
             {
                 const index_type offset = k*WG_SIZE + LID_0;
 
-                const index_type i = unit_base + offset;
+                const index_type read_i = unit_base + offset;
 
                 %if is_tail:
-                if (i < interval_end)
+                if (read_i < interval_end)
                 %endif
                 {
-                    ldata[offset % K][offset / K] = INPUT_EXPR(i);
+                    ldata[offset % K][offset / K] = INPUT_EXPR(read_i);
                 }
             }
 
@@ -219,12 +221,12 @@ void ${name_prefix}_scan_intervals(
                 %endif
                 {
                     scan_type tmp = ldata[k][LID_0];
-                    index_type i = unit_base + K*LID_0 + k;
+                    index_type seq_i = unit_base + K*LID_0 + k;
 
                     %if is_segmented:
-                    if (IS_SEG_START(i, tmp)
+                    if (IS_SEG_START(seq_i, tmp))
                     {
-                        first_segment_start_in_k_group = i;
+                        first_segment_start_in_k_group = seq_i;
                         sum = tmp;
                     }
                     else
@@ -250,9 +252,9 @@ void ${name_prefix}_scan_intervals(
 
             // This tree-based scan works as follows:
             // - Each work item adds the previous item to its current state
-            // - barrier sync
+            // - barrier
             // - Each work item adds in the item from two positions to the left
-            // - barrier sync
+            // - barrier
             // - Each work item adds in the item from four positions to the left
             // ...
             // At the end, each item has summed all prior items.
@@ -392,15 +394,17 @@ void ${name_prefix}_scan_intervals(
         %endif
     }
 }
-""", strict_undefined=True, disable_unicode=True)
+"""
 
 # }}}
 
-# {{{ inclusive update
+# {{{ local update
+
+# used for inclusive scan
 
-INCLUSIVE_UPDATE_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL//
+LOCAL_UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//
 
-#define OUTPUT_STMT(i, a) ${output_statement}
+#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; }
 
 KERNEL
 REQD_WG_SIZE(WG_SIZE, 1, 1)
@@ -415,8 +419,10 @@ void ${name_prefix}_final_update(
     %endif
     )
 {
-    if (GID_0 == 0)
-        return;
+    %if neutral is None:
+        if (GID_0 == 0)
+            return;
+    %endif
 
     const index_type interval_begin = interval_size * GID_0;
     const index_type interval_end = min(interval_begin + interval_size, N);
@@ -426,31 +432,40 @@ void ${name_prefix}_final_update(
     %endif
 
     // value to add to this segment
-    scan_type prev_group_sum = interval_results[GID_0 - 1];
+    scan_type prev_group_sum;
+    if (GID_0 == 0)
+        prev_group_sum = ${neutral};
+    else
+        prev_group_sum = interval_results[GID_0 - 1];
 
     for(index_type unit_base = interval_begin;
         unit_base < interval_end;
         unit_base += WG_SIZE)
     {
-        const index_type i = unit_base + LID_0;
+        const index_type update_i = unit_base + LID_0;
 
-        if(i < interval_end)
+        if(update_i < interval_end)
         {
-            scan_type val = partial_scan_buffer[i];
-            scan_type value = SCAN_EXPR(prev_group_sum, val);
-            OUTPUT_STMT(i, value)
+            scan_type partial_val = partial_scan_buffer[update_i];
+            scan_type value = SCAN_EXPR(prev_group_sum, partial_val);
+
+            // printf("i: %d pgs: %d pv: %d val: %d\n", update_i, prev_group_sum, partial_val, value);
+
+            OUTPUT_STMT(update_i, prev_item_unavailable_with_local_update, value);
         }
     }
 }
-""", strict_undefined=True, disable_unicode=True)
+"""
 
 # }}}
 
-# {{{ exclusive update
+# {{{ lookbehind update
+
+# used for exclusive scan or output_statements that request look-behind
 
-EXCLUSIVE_UPDATE_SOURCE = mako.template.Template(SHARED_PREAMBLE + """//CL//
+LOOKBEHIND_UPDATE_SOURCE = SHARED_PREAMBLE + """//CL//
 
-#define OUTPUT_STMT(i, a) ${output_statement}
+#define OUTPUT_STMT(i, prev_item, item) { ${output_statement}; }
 
 KERNEL
 REQD_WG_SIZE(WG_SIZE, 1, 1)
@@ -475,18 +490,18 @@ void ${name_prefix}_final_update(
     if(GID_0 != 0)
         carry = interval_results[GID_0 - 1];
 
-    scan_type value = carry; // (A)
+    scan_type prev_value = carry; // (A)
 
     for (index_type unit_base = interval_begin;
         unit_base < interval_end;
         unit_base += WG_SIZE)
     {
-        const index_type i = unit_base + LID_0;
+        const index_type update_i = unit_base + LID_0;
 
         // load a work group's worth of data
-        if (i < interval_end)
+        if (update_i < interval_end)
         {
-            scan_type tmp = partial_scan_buffer[i];
+            scan_type tmp = partial_scan_buffer[update_i];
             ldata[LID_0] = SCAN_EXPR(carry, tmp);
         }
 
@@ -494,36 +509,42 @@ void ${name_prefix}_final_update(
 
         // perform right shift
         if (LID_0 != 0)
-            value = ldata[LID_0 - 1];
+            prev_value = ldata[LID_0 - 1];
         /*
-        else 
-            value = carry (see (A)) OR last tail (see (B));
+        else
+            prev_value = carry (see (A)) OR last tail (see (B));
         */
 
         %if is_segmented:
         {
-            scan_type scan_item_at_i = INPUT_EXPR(i)
-            bool is_seg_start = IS_SEG_START(i, scan_item_at_i);
+            scan_type scan_item_at_i = INPUT_EXPR(update_i);
+            bool is_seg_start = IS_SEG_START(update_i, scan_item_at_i);
             if (is_seg_start)
-                value = ${neutral};
+                prev_value = ${neutral};
         }
         %endif
 
-        if (i < interval_end)
+        if (update_i < interval_end)
         {
-            OUTPUT_STMT(i, value)
+            scan_type value = ldata[LID_0];
+
+            OUTPUT_STMT(update_i, prev_value, value)
         }
 
         if(LID_0 == 0)
-            value = ldata[WG_SIZE - 1]; // (B)
+            prev_value = ldata[WG_SIZE - 1]; // (B)
 
         local_barrier();
     }
 }
-""", strict_undefined=True, disable_unicode=True)
+"""
 
 # }}}
 
+# {{{ driver
+
+# {{{ helpers
+
 def _round_down_to_power_of_2(val):
     result = 2**bitlog2(val)
     if result > val:
@@ -532,14 +553,6 @@ def _round_down_to_power_of_2(val):
     assert result <= val
     return result
 
-
-
-
-
-# {{{ driver
-
-# {{{ helpers
-
 def _parse_args(arguments):
     from pyopencl.tools import parse_c_arg
     return [parse_c_arg(arg) for arg in arguments.split(",")]
@@ -556,8 +569,83 @@ def _get_scalar_arg_dtypes(arg_types):
 
     return result
 
+_PREFIX_WORDS = set("""
+        ldata partial_scan_buffer global scan_offset
+        segment_start_in_k_group carry
+        g_first_segment_start_in_interval IS_SEG_START tmp Z
+        val l_first_segment_start_in_k_group unit_size
+        index_type interval_begin interval_size offset_end K
+        SCAN_EXPR do_update NO_SEG_BOUNDARY WG_SIZE
+        first_segment_start_in_k_group scan_type
+        segment_start_in_subtree offset interval_results interval_end
+        first_segment_start_in_subtree unit_base
+        first_segment_start_in_interval k INPUT_EXPR
+        prev_group_sum prev pv this add value n partial_val pgs OUTPUT_STMT
+        is_seg_start update_i scan_item_at_i seq_i read_i
+        """.split())
+
+_IGNORED_WORDS = set("""
+        typedef for endfor if void while endwhile endfor endif else const printf
+        None return bool
+        LID_2 LID_1 LID_0
+        LDIM_0 LDIM_1 LDIM_2
+        GDIM_0 GDIM_1 GDIM_2
+        GID_0 GID_1 GID_2
+        GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
+        local_barrier
+        CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
+        pragma __attribute__ __global __kernel __local
+        get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
+        get_num_groups barrier get_group_id
+
+        _final_update _scan_intervals
+
+        positions all padded integer its previous write based true writes 0
+        has local worth scan_expr to read cannot not X items False bank
+        four beginning follows applicable sum item min each indices works side
+        scanning right summed relative used id out index avoid current state
+        boundary True across be This reads groups along Otherwise undetermined
+        store of times prior s update first regardless Each number because
+        array unit from segment conflicts two parallel 2 empty define direction
+        CL padding work tree bounds values and adds
+        scan is allowed thus it an as enable at in occur sequentially end no
+        storage data 1 largest may representable uses entry Y meaningful
+        computations interval At the left dimension know d
+        A load B group perform shift tail see last OR
+
+        is_tail is_first_level input_expr argument_signature preamble
+        double_support neutral output_statement index_type_max
+        k_group_size name_prefix is_segmented index_ctype scan_ctype
+        wg_size is_i_segment_start_expr
+
+        a b prev_item i prev_item_unavailable_with_local_update prev_value
+        N
+        """.split())
+
+def _make_template(s):
+    leftovers = set()
+
+    def replace_id(match):
+        # avoid name clashes with user code by adding 'psc_' prefix to
+        # identifiers.
+
+        word = match.group(1)
+        if word in _IGNORED_WORDS:
+            return word
+        elif word in _PREFIX_WORDS:
+            return "psc_"+word
+        else:
+            leftovers.add(word)
+            return word
+
+    import re
+    s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
 
+    if leftovers:
+        from warnings import warn
+        warn("leftover words in identifier prefixing: " + " ".join(leftovers))
 
+    return mako.template.Template(s, strict_undefined=True, disable_unicode=True)
 
 from pytools import Record
 class _ScanKernelInfo(Record):
@@ -565,10 +653,10 @@ class _ScanKernelInfo(Record):
 
 # }}}
 
-class _GenericScanKernelBase(object):
+class GenericScanKernel(object):
     def __init__(self, ctx, dtype,
-            arguments, scan_expr, input_expr, output_statement,
-            neutral=None, is_i_segment_start_expr=None,
+            arguments, input_expr, scan_expr, neutral, output_statement,
+            is_i_segment_start_expr=None,
             partial_scan_buffer_name=None,
             name_prefix="scan", options=[], preamble="", devices=None):
         """
@@ -585,8 +673,11 @@ class _GenericScanKernelBase(object):
             to each array entry when scan first touches it. *arguments*
             must be given if *input_expr* is given.
         :arg output_statement: a C statement that writes
-            the output of the scan. It has access to the scan result as `a`
-            and the current index as `i`.
+            the output of the scan. It has access to the scan result as *item*,
+            the preceding scan result item as *prev_item*, and the current index
+            as *i*. *prev_item* is unavailable when using exclusive scan.
+            *prev_item* in a segmented scan will be the neutral element
+            at a segment boundary, not the immediately preceding item.
         :arg is_i_segment_start_expr: If given, makes the scan a segmented
             scan. Has access to the current index `i` and the input element
             as `a` and returns a bool. If it returns true, then previous
@@ -594,6 +685,9 @@ class _GenericScanKernelBase(object):
 
         The first array in the argument list determines the size of the index
         space over which the scan is carried out.
+
+        All code fragments further have access to N, the number of elements
+        being processed in the scan.
         """
 
         if isinstance(self, ExclusiveScanKernel) and neutral is None:
@@ -601,7 +695,13 @@ class _GenericScanKernelBase(object):
 
         self.context = ctx
         dtype = self.dtype = np.dtype(dtype)
-        self.neutral = neutral
+
+        if neutral is None:
+            from warnings import warn
+            warn("not specifying 'neutral' is deprecated and will lead to "
+                    "wrong results if your scan is not in-place or your "
+                    "'output_statement' otherwise does something non-trivial",
+                    stacklevel=2)
 
         self.index_dtype = np.dtype(np.uint32)
 
@@ -618,8 +718,12 @@ class _GenericScanKernelBase(object):
                 if isinstance(arg, VectorArg)][0]
 
         self.is_segmented = is_i_segment_start_expr is not None
+        if self.is_segmented:
+            is_i_segment_start_expr = is_i_segment_start_expr.replace("\n", " ")
 
-        if self.is_segmented and self.is_exclusive:
+        use_lookbehind_update = "prev_item" in output_statement
+
+        if self.is_segmented and use_lookbehind_update:
             # The final update in segmented exclusive scan must be able to
             # reconstruct where the segment boundaries were, and therefore
             # can't overwrite any of the input.
@@ -642,10 +746,10 @@ class _GenericScanKernelBase(object):
             name_prefix=name_prefix,
             index_ctype=dtype_to_ctype(self.index_dtype),
             index_type_max=str(np.iinfo(self.index_dtype).max),
-            scan_type=dtype_to_ctype(dtype),
+            scan_ctype=dtype_to_ctype(dtype),
             is_segmented=self.is_segmented,
-            scan_expr=scan_expr,
-            neutral=neutral,
+            scan_expr=scan_expr.replace("\n", " "),
+            neutral=None if neutral is None else neutral.replace("\n", " "),
             double_support=all(
                 has_double_support(dev) for dev in devices),
             )
@@ -724,12 +828,18 @@ class _GenericScanKernelBase(object):
 
         self.update_wg_size = min(max_scan_wg_size, 256)
 
-        final_update_src = str(self.final_update_tp.render(
+        if use_lookbehind_update:
+            update_src = LOOKBEHIND_UPDATE_SOURCE
+        else:
+            update_src = LOCAL_UPDATE_SOURCE
+
+        final_update_tpl = _make_template(update_src)
+        final_update_src = str(final_update_tpl.render(
             wg_size=self.update_wg_size,
-            output_statement=output_statement,
-            argument_signature=arguments,
+            output_statement=output_statement.replace("\n", " "),
+            argument_signature=arguments.replace("\n", " "),
             is_i_segment_start_expr=is_i_segment_start_expr,
-            input_expr=input_expr,
+            input_expr=input_expr.replace("\n", " "),
             **self.code_variables))
 
         final_update_prg = cl.Program(self.context, final_update_src).build(options)
@@ -759,11 +869,12 @@ class _GenericScanKernelBase(object):
         else:
             k_group_size = 8
 
-        scan_intervals_src = str(SCAN_INTERVALS_SOURCE.render(
+        scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
+        scan_intervals_src = str(scan_tpl.render(
             wg_size=wg_size,
             input_expr=input_expr,
             k_group_size=k_group_size,
-            argument_signature=arguments,
+            argument_signature=arguments.replace("\n", " "),
             is_i_segment_start_expr=is_i_segment_start_expr,
             is_first_level=is_first_level,
             **self.code_variables))
@@ -830,7 +941,7 @@ class _GenericScanKernelBase(object):
         interval_results = allocator(self.dtype.itemsize*num_intervals)
 
         if self.partial_scan_buffer_idx is None:
-            partial_scan_buffer = allocator(n)
+            partial_scan_buffer = allocator(n*self.dtype.itemsize)
         else:
             partial_scan_buffer = data_args[self.partial_scan_buffer_idx]
 
@@ -846,6 +957,11 @@ class _GenericScanKernelBase(object):
                 queue, (num_intervals,), (l1_info.wg_size,),
                 *scan1_args, **dict(g_times_l=True))
 
+        if 0:
+            psb_host = np.empty(n, self.dtype)
+            cl.enqueue_copy(queue, psb_host, partial_scan_buffer)
+            print "PSB", psb_host
+
         # }}}
 
         # {{{ second level inclusive scan of per-interval results
@@ -878,29 +994,21 @@ class _GenericScanKernelBase(object):
 
 # }}}
 
+# {{{ compatibility interface
 
-
-class GenericInclusiveScanKernel(_GenericScanKernelBase):
-    final_update_tp = INCLUSIVE_UPDATE_SOURCE
-    is_exclusive = False
-
-class GenericExclusiveScanKernel(_GenericScanKernelBase):
-    final_update_tp = EXCLUSIVE_UPDATE_SOURCE
-    is_exclusive = True
-
-class _ScanKernelBase(_GenericScanKernelBase):
+class _ScanKernelBase(GenericScanKernel):
     def __init__(self, ctx, dtype,
             scan_expr, neutral=None,
             name_prefix="scan", options=[], preamble="", devices=None):
         scan_ctype = dtype_to_ctype(dtype)
-        _GenericScanKernelBase.__init__(self,
+        GenericScanKernel.__init__(self,
                 ctx, dtype,
                 arguments="__global %s *input_ary, __global %s *output_ary" % (
                     scan_ctype, scan_ctype),
-                scan_expr=scan_expr,
                 input_expr="input_ary[i]",
-                output_statement="output_ary[i] = a;",
+                scan_expr=scan_expr,
                 neutral=neutral,
+                output_statement=self.ary_output_statement,
                 partial_scan_buffer_name="output_ary",
                 options=options, preamble=preamble, devices=devices)
 
@@ -926,17 +1034,56 @@ class _ScanKernelBase(_GenericScanKernelBase):
         if not n:
             return output_ary
 
-        _GenericScanKernelBase.__call__(self,
+        GenericScanKernel.__call__(self,
                 input_ary, output_ary, allocator=allocator, queue=queue)
 
         return output_ary
 
 class InclusiveScanKernel(_ScanKernelBase):
-    final_update_tp = INCLUSIVE_UPDATE_SOURCE
-    is_exclusive = False
+    ary_output_statement = "output_ary[i] = item;"
 
 class ExclusiveScanKernel(_ScanKernelBase):
-    final_update_tp = EXCLUSIVE_UPDATE_SOURCE
-    is_exclusive = True
+    ary_output_statement = "output_ary[i] = prev_item;"
+
+# }}}
+
+# {{{ higher-level trickery
+
+@context_dependent_memoize
+def get_copy_if_kernel(ctx, dtype, predicate, scan_dtype):
+    ctype = dtype_to_ctype(dtype)
+    return GenericScanKernel(
+            ctx, dtype,
+            arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype),
+            input_expr="(%s) ? 1 : 0" % predicate,
+            scan_expr="a+b", neutral="0",
+            output_statement="""
+                if (prev_item != item) out[item-1] = ary[i];
+                if (i+1 == N) *count = item;
+                """
+            )
+
+def copy_if(ary, predicate, queue=None):
+    if len(ary) > np.iinfo(np.uint32).max:
+        scan_dtype = np.uint64
+    else:
+        scan_dtype = np.uint32
+
+    knl = get_copy_if_kernel(ary.context, ary.dtype, predicate, scan_dtype)
+    out = cl_array.empty_like(ary)
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    knl(ary, out, count, queue=queue)
+    return out, count
+
+def remove_if(array, predicate, **kwargs):
+    pass
+
+def partition(array, predicate):
+    pass
+
+def unique_by_key(array, key="", **kwargs):
+    pass
+
+# }}}
 
 # vim: filetype=pyopencl:fdm=marker
diff --git a/test/test_array.py b/test/test_array.py
index 7b8bdf0b45d0f5f8d979e60fcdbe2ae14d223f86..c8ee6a003e27232a7f8b2ecd0158daf61c94bdd9 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -741,6 +741,17 @@ def summarize_error(obtained, desired, orig, thresh=1e-5):
 
     return " ".join(entries)
 
+scan_test_counts = [
+    10,
+    2 ** 10 - 5,
+    2 ** 10,
+    2 ** 10 + 5,
+    2 ** 20 - 2 ** 18,
+    2 ** 20 - 2 ** 18 + 5,
+    2 ** 20 + 1,
+    2 ** 20, 2 ** 24
+    ]
+
 @pytools.test.mark_test.opencl
 def test_scan(ctx_factory):
     context = ctx_factory()
@@ -755,16 +766,7 @@ def test_scan(ctx_factory):
             ]:
         knl = cls(context, dtype, "a+b", "0")
 
-        for n in [
-                10,
-                2 ** 10 - 5,
-                2 ** 10,
-                2 ** 10 + 5,
-                2 ** 20 - 2 ** 18,
-                2 ** 20 - 2 ** 18 + 5,
-                2 ** 20 + 1,
-                2 ** 20, 2 ** 24
-                ]:
+        for n in scan_test_counts:
 
             host_data = np.random.randint(0, 10, n).astype(dtype)
             dev_data = cl_array.to_device(queue, host_data)
@@ -778,7 +780,7 @@ def test_scan(ctx_factory):
                 desired_result -= host_data
 
             is_ok = (dev_data.get() == desired_result).all()
-            if 0 and not is_ok:
+            if 1 and not is_ok:
                 print(summarize_error(dev_data.get(), desired_result, host_data))
 
             print n, is_ok
@@ -786,6 +788,22 @@ def test_scan(ctx_factory):
             from gc import collect
             collect()
 
+@pytools.test.mark_test.opencl
+def test_copy_if(ctx_factory):
+    context = ctx_factory()
+    queue = cl.CommandQueue(context)
+
+    from pyopencl.clrandom import rand as clrand
+    for n in scan_test_counts:
+        a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000)
+        a = a_dev.get()
+
+        from pyopencl.scan import copy_if
+
+        selected = a[a>300]
+        selected_dev, count_dev = copy_if(a_dev, "ary[i] > 300")
+
+        assert (selected_dev.get()[:count_dev.get()] == selected).all()
 
 @pytools.test.mark_test.opencl
 def test_stride_preservation(ctx_factory):