diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 527ab76a78dbd8de6c433146781c03150838bf5e..6c22572128717394beb90a2c19752b52b58acf72 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -577,7 +577,7 @@ _PREFIX_WORDS = set("""
         index_type interval_begin interval_size offset_end K
         SCAN_EXPR do_update NO_SEG_BOUNDARY WG_SIZE
         first_segment_start_in_k_group scan_type
-        segment_start_in_subtree offset interval_results interval_end N
+        segment_start_in_subtree offset interval_results interval_end
         first_segment_start_in_subtree unit_base
         first_segment_start_in_interval k INPUT_EXPR
         prev_group_sum prev pv this add value n partial_val pgs OUTPUT_STMT
@@ -619,6 +619,7 @@ _IGNORED_WORDS = set("""
         wg_size is_i_segment_start_expr
 
         a b prev_item i prev_item_unavailable_with_local_update prev_value
+        N
         """.split())
 
 def _make_template(s):
@@ -684,6 +685,9 @@ class GenericScanKernel(object):
 
         The first array in the argument list determines the size of the index
         space over which the scan is carried out.
+
+        All code fragments further have access to N, the number of elements
+        being processed in the scan.
         """
 
         if isinstance(self, ExclusiveScanKernel) and neutral is None:
@@ -714,6 +718,9 @@ class GenericScanKernel(object):
                 if isinstance(arg, VectorArg)][0]
 
         self.is_segmented = is_i_segment_start_expr is not None
+        if self.is_segmented:
+            is_i_segment_start_expr = is_i_segment_start_expr.replace("\n", " ")
+
         use_lookbehind_update = "prev_item" in output_statement
 
         if self.is_segmented and use_lookbehind_update:
@@ -741,8 +748,8 @@ class GenericScanKernel(object):
             index_type_max=str(np.iinfo(self.index_dtype).max),
             scan_ctype=dtype_to_ctype(dtype),
             is_segmented=self.is_segmented,
-            scan_expr=scan_expr,
-            neutral=neutral,
+            scan_expr=scan_expr.replace("\n", " "),
+            neutral=neutral.replace("\n", " "),
             double_support=all(
                 has_double_support(dev) for dev in devices),
             )
@@ -829,10 +836,10 @@ class GenericScanKernel(object):
         final_update_tpl = _make_template(update_src)
         final_update_src = str(final_update_tpl.render(
             wg_size=self.update_wg_size,
-            output_statement=output_statement,
-            argument_signature=arguments,
+            output_statement=output_statement.replace("\n", " "),
+            argument_signature=arguments.replace("\n", " "),
             is_i_segment_start_expr=is_i_segment_start_expr,
-            input_expr=input_expr,
+            input_expr=input_expr.replace("\n", " "),
             **self.code_variables))
 
         final_update_prg = cl.Program(self.context, final_update_src).build(options)
@@ -867,7 +874,7 @@ class GenericScanKernel(object):
             wg_size=wg_size,
             input_expr=input_expr,
             k_group_size=k_group_size,
-            argument_signature=arguments,
+            argument_signature=arguments.replace("\n", " "),
             is_i_segment_start_expr=is_i_segment_start_expr,
             is_first_level=is_first_level,
             **self.code_variables))
@@ -1047,18 +1054,24 @@ def get_copy_if_kernel(ctx, dtype, predicate):
     ctype = dtype_to_ctype(dtype)
     return GenericScanKernel(
             ctx, np.uint32,
-            arguments="__global %s *ary, __global %s *out" % (ctype, ctype),
+            arguments="__global %s *ary, __global %s *out, __global unsigned long *count" % (ctype, ctype),
             input_expr="(%s) ? 1 : 0" % predicate,
             scan_expr="a+b", neutral="0",
-            output_statement="if (prev_item != item) out[item-1] = ary[i];"
+            output_statement="""
+                if (prev_item != item) out[item-1] = ary[i];
+                if (i+1 == N) *count = item;
+                """
             )
 
 def copy_if(ary, predicate, queue=None):
+    # FIXME use 64-bit scan, eventually
+    # (not relevant for 6GB GPUs)
+
     knl = get_copy_if_kernel(ary.context, ary.dtype, predicate)
-    # FIXME return count
     out = cl_array.empty_like(ary)
-    knl(ary, out, queue=queue)
-    return out
+    count = ary._new_with_changes(data=None, shape=(), strides=(), dtype=np.uint64)
+    knl(ary, out, count, queue=queue)
+    return out, count
 
 def remove_if(array, predicate, **kwargs):
     pass
diff --git a/test/test_array.py b/test/test_array.py
index 984334ba0ee9e57288077f12159806ae8bad2ce2..a540971811a4de9c92f97d950801da9202bf7f70 100644
--- a/test/test_array.py
+++ b/test/test_array.py
@@ -714,6 +714,17 @@ def summarize_error(obtained, desired, orig, thresh=1e-5):
 
     return " ".join(entries)
 
+scan_test_counts = [
+    10,
+    2 ** 10 - 5,
+    2 ** 10,
+    2 ** 10 + 5,
+    2 ** 20 - 2 ** 18,
+    2 ** 20 - 2 ** 18 + 5,
+    2 ** 20 + 1,
+    2 ** 20, 2 ** 24
+    ]
+
 @pytools.test.mark_test.opencl
 def test_scan(ctx_factory):
     context = ctx_factory()
@@ -728,16 +739,7 @@ def test_scan(ctx_factory):
             ]:
         knl = cls(context, dtype, "a+b", "0")
 
-        for n in [
-                10,
-                2 ** 10 - 5,
-                2 ** 10,
-                2 ** 10 + 5,
-                2 ** 20 - 2 ** 18,
-                2 ** 20 - 2 ** 18 + 5,
-                2 ** 20 + 1,
-                2 ** 20, 2 ** 24
-                ]:
+        for n in scan_test_counts:
 
             host_data = np.random.randint(0, 10, n).astype(dtype)
             dev_data = cl_array.to_device(queue, host_data)
@@ -765,15 +767,16 @@ def test_copy_if(ctx_factory):
     queue = cl.CommandQueue(context)
 
     from pyopencl.clrandom import rand as clrand
-    a_dev = clrand(queue, (200000,), dtype=np.int32, a=0, b=1000)
-    a = a_dev.get()
+    for n in scan_test_counts:
+        a_dev = clrand(queue, (n,), dtype=np.int32, a=0, b=1000)
+        a = a_dev.get()
 
-    from pyopencl.scan import copy_if
+        from pyopencl.scan import copy_if
 
-    selected = a[a>300]
-    selected_dev = copy_if(a_dev, "ary[i] > 300").get()[:len(selected)]
+        selected = a[a>300]
+        selected_dev, count_dev = copy_if(a_dev, "ary[i] > 300")
 
-    assert (selected_dev == selected).all()
+        assert (selected_dev.get()[:count_dev.get()] == selected).all()
 
 @pytools.test.mark_test.opencl
 def test_stride_preservation(ctx_factory):