diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 5f51fa14b568e5fa739fa6b3716b8f5becf01b20..91cff1e33238149201ca59e48d2017bfbee62484 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -950,15 +950,33 @@ class GenericScanKernel(_GenericScanKernelBase):
         use_lookbehind_update = "prev_item" in self.output_statement
         self.store_segment_start_flags = self.is_segmented and use_lookbehind_update
 
-        # {{{ loop to find usable workgroup size, build first-level scan
+        # {{{ find usable workgroup/k-group size, build first-level scan
 
         trip_count = 0
 
-        max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
 
         if self.devices[0].type == cl.device_type.CPU:
-            # (about the widest vector a CPU can support)
+            # (about the widest vector a CPU can support, also taking
+            # into account that CPUs don't hide latency by large work groups
             max_scan_wg_size = 16
+        else:
+            max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
+
+        avail_local_mem = min(
+                dev.local_mem_size
+                for dev in self.devices)
+
+        # k_group_size should be a power of two because of in-kernel
+        # division by that number.
+
+        k_group_size = 128
+
+        while (
+                self.get_local_mem_use(
+                    max_scan_wg_size, k_group_size) + 256  > avail_local_mem):
+            k_group_size //= 2
+
+        assert k_group_size > 1
 
         while True:
             candidate_scan_info = self.build_scan_kernel(
@@ -966,7 +984,8 @@ class GenericScanKernel(_GenericScanKernelBase):
                     self.is_segment_start_expr,
                     input_fetch_exprs=self.input_fetch_exprs,
                     is_first_level=True,
-                    store_segment_start_flags=self.store_segment_start_flags)
+                    store_segment_start_flags=self.store_segment_start_flags,
+                    k_group_size=k_group_size)
 
             # Will this device actually let us execute this kernel
             # at the desired work group size? Building it is the
@@ -980,10 +999,10 @@ class GenericScanKernel(_GenericScanKernelBase):
             if candidate_scan_info.wg_size <= kernel_max_wg_size:
                 break
             else:
-                max_scan_wg_size = kernel_max_wg_size
+                max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)
 
             trip_count += 1
-            assert trip_count <= 2
+            assert trip_count <= 20
 
         self.first_level_scan_info = candidate_scan_info
         assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
@@ -1019,6 +1038,7 @@ class GenericScanKernel(_GenericScanKernelBase):
                 input_fetch_exprs=[],
                 is_first_level=False,
                 store_segment_start_flags=False,
+                k_group_size=k_group_size,
                 **second_level_build_kwargs)
 
         assert min(
@@ -1058,25 +1078,42 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         # }}}
 
-    # {{{ scan kernel build
+    # {{{ scan kernel build/properties
+
+    def get_local_mem_use(self, k_group_size, wg_size):
+        arg_dtypes = {}
+        for arg in self.parsed_args:
+            arg_dtypes[arg.name] = arg.dtype
+
+        fetch_expr_offsets = {}
+        for name, arg_name, ife_offset in self.input_fetch_exprs:
+            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+        return (
+                # ldata
+                self.dtype.itemsize*(k_group_size+1)*(wg_size+1)
+
+                # l_segment_start_flags
+                + k_group_size*wg_size
+
+                # l_first_segment_start_in_subtree
+                + self.index_dtype.itemsize*wg_size
+
+                + k_group_size*wg_size*sum(
+                    self.arg_dtypes[arg_name]
+                    for arg_name, ife_offsets in fetch_expr_offsets.items()
+                    if -1 in ife_offsets or len(ife_offsets) > 1))
+
 
     def build_scan_kernel(self, max_wg_size, arguments, input_expr,
             is_segment_start_expr, input_fetch_exprs, is_first_level,
-            store_segment_start_flags):
+            store_segment_start_flags, k_group_size):
         scalar_arg_dtypes = _get_scalar_arg_dtypes(arguments)
 
         # Thrust says that 128 is big enough for GT200
         wg_size = _round_down_to_power_of_2(
                 min(max_wg_size, 128))
 
-        # k_group_size should be a power of two because of in-kernel
-        # division by that number.
-
-        # FIXME: guesswork
-        if self.devices[0].type == cl.device_type.CPU:
-            k_group_size = 128
-        else:
-            k_group_size = 8
 
         scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
         scan_src = str(scan_tpl.render(