diff --git a/pyopencl/scan.py b/pyopencl/scan.py index 526f679c8f46fc0b7a782b33f5f32a21933a375e..03bfef17e1bc9b51f442aac9b8863f439efedb72 100644 --- a/pyopencl/scan.py +++ b/pyopencl/scan.py @@ -988,8 +988,10 @@ class GenericScanKernel(_GenericScanKernelBase): # (about the widest vector a CPU can support, also taking # into account that CPUs don't hide latency by large work groups max_scan_wg_size = 16 + wg_size_multiples = 16 else: max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices) + wg_size_multiples = 64 avail_local_mem = min( dev.local_mem_size @@ -998,14 +1000,17 @@ class GenericScanKernel(_GenericScanKernelBase): # k_group_size should be a power of two because of in-kernel # division by that number. - k_group_size = 128 + solutions = [] + for k_exp in range(0, 7): + for wg_size in range(wg_size_multiples, max_scan_wg_size+1, + wg_size_multiples): - while ( - self.get_local_mem_use( - max_scan_wg_size, k_group_size) + 256 > avail_local_mem): - k_group_size //= 2 + k_group_size = 2**k_exp + if (self.get_local_mem_use( + wg_size, k_group_size) + 256 <= avail_local_mem): + solutions.append((wg_size*k_group_size, k_group_size, wg_size)) - assert k_group_size > 1 + _, k_group_size, max_scan_wg_size = max(solutions) while True: candidate_scan_info = self.build_scan_kernel(