diff --git a/pyopencl/scan.py b/pyopencl/scan.py
index 3a82172af5311658ea4f5b1c0c5dd75236e630cb..af428afe32f26f3f0c0bc54ebd649a849dd15a4d 100644
--- a/pyopencl/scan.py
+++ b/pyopencl/scan.py
@@ -975,7 +975,6 @@ class GenericScanKernel(_GenericScanKernelBase):
 
         trip_count = 0
 
-
         if self.devices[0].type == cl.device_type.CPU:
             # (about the widest vector a CPU can support, also taking
             # into account that CPUs don't hide latency by large work groups
@@ -1002,6 +1001,19 @@ class GenericScanKernel(_GenericScanKernelBase):
                     wg_size, k_group_size) + 256  <= avail_local_mem):
                     solutions.append((wg_size*k_group_size, k_group_size, wg_size))
 
+        if self.devices[0].type == cl.device_type.GPU:
+            from pytools import any
+            for wg_size_floor in [256, 192, 128]:
+                have_sol_above_floor = any(wg_size >= wg_size_floor
+                        for _, _, wg_size in solutions)
+
+                if have_sol_above_floor:
+                    # delete all the others
+                    solutions = [(total, k_group_size, wg_size)
+                            for total, k_group_size, wg_size in solutions
+                            if wg_size >= wg_size_floor]
+                    break
+
         _, k_group_size, max_scan_wg_size = max(solutions)
 
         while True:
@@ -1140,7 +1152,6 @@ class GenericScanKernel(_GenericScanKernelBase):
         wg_size = _round_down_to_power_of_2(
                 min(max_wg_size, 128))
 
-
         scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
         scan_src = str(scan_tpl.render(
             wg_size=wg_size,