diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 526f679c8f46fc0b7a782b33f5f32a21933a375e..03bfef17e1bc9b51f442aac9b8863f439efedb72 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -988,8 +988,10 @@ class GenericScanKernel(_GenericScanKernelBase):
             # (about the widest vector a CPU can support, also taking
             # into account that CPUs don't hide latency by large work groups
             max_scan_wg_size = 16
+            wg_size_multiples = 16
         else:
             max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
+            wg_size_multiples = 64
 
         avail_local_mem = min(
                 dev.local_mem_size
@@ -998,14 +1000,17 @@ class GenericScanKernel(_GenericScanKernelBase):
         # k_group_size should be a power of two because of in-kernel
         # division by that number.
 
-        k_group_size = 128
+        solutions = []
+        for k_exp in range(0, 7):
+            for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
+                    wg_size_multiples):
 
-        while (
-                self.get_local_mem_use(
-                    max_scan_wg_size, k_group_size) + 256  > avail_local_mem):
-            k_group_size //= 2
+                k_group_size = 2**k_exp
+                if (self.get_local_mem_use(
+                    wg_size, k_group_size) + 256  <= avail_local_mem):
+                    solutions.append((wg_size*k_group_size, k_group_size, wg_size))
 
-        assert k_group_size > 1
+        _, k_group_size, max_scan_wg_size = max(solutions)
 
         while True:
             candidate_scan_info = self.build_scan_kernel(