diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 3a82172af5311658ea4f5b1c0c5dd75236e630cb..af428afe32f26f3f0c0bc54ebd649a849dd15a4d 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -975,7 +975,6 @@ class GenericScanKernel(_GenericScanKernelBase): trip_count = 0 - if self.devices[0].type == cl.device_type.CPU: # (about the widest vector a CPU can support, also taking # into account that CPUs don't hide latency by large work groups @@ -1002,6 +1001,19 @@ class GenericScanKernel(_GenericScanKernelBase): wg_size, k_group_size) + 256 <= avail_local_mem): solutions.append((wg_size*k_group_size, k_group_size, wg_size)) + if self.devices[0].type == cl.device_type.GPU: + from pytools import any + for wg_size_floor in [256, 192, 128]: + have_sol_above_floor = any(wg_size >= wg_size_floor + for _, _, wg_size in solutions) + + if have_sol_above_floor: + # delete all the others + solutions = [(total, k_group_size, wg_size) + for total, k_group_size, wg_size in solutions + if wg_size >= wg_size_floor] + break + _, k_group_size, max_scan_wg_size = max(solutions) while True: @@ -1140,7 +1152,6 @@ class GenericScanKernel(_GenericScanKernelBase): wg_size = _round_down_to_power_of_2( min(max_wg_size, 128)) - scan_tpl = _make_template(SCAN_INTERVALS_SOURCE) scan_src = str(scan_tpl.render( wg_size=wg_size,