Skip to content
Snippets Groups Projects
Commit 7f06cf21 authored by Andreas Klöckner
Browse files

Introduce work group size floor in scan.

parent 82936bbf
No related branches found
No related tags found
No related merge requests found
...@@ -975,7 +975,6 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -975,7 +975,6 @@ class GenericScanKernel(_GenericScanKernelBase):
trip_count = 0 trip_count = 0
if self.devices[0].type == cl.device_type.CPU: if self.devices[0].type == cl.device_type.CPU:
# (about the widest vector a CPU can support, also taking # (about the widest vector a CPU can support, also taking
# into account that CPUs don't hide latency by large work groups # into account that CPUs don't hide latency by large work groups
...@@ -1002,6 +1001,19 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1002,6 +1001,19 @@ class GenericScanKernel(_GenericScanKernelBase):
wg_size, k_group_size) + 256 <= avail_local_mem): wg_size, k_group_size) + 256 <= avail_local_mem):
solutions.append((wg_size*k_group_size, k_group_size, wg_size)) solutions.append((wg_size*k_group_size, k_group_size, wg_size))
if self.devices[0].type == cl.device_type.GPU:
from pytools import any
for wg_size_floor in [256, 192, 128]:
have_sol_above_floor = any(wg_size >= wg_size_floor
for _, _, wg_size in solutions)
if have_sol_above_floor:
# delete all the others
solutions = [(total, k_group_size, wg_size)
for total, k_group_size, wg_size in solutions
if wg_size >= wg_size_floor]
break
_, k_group_size, max_scan_wg_size = max(solutions) _, k_group_size, max_scan_wg_size = max(solutions)
while True: while True:
...@@ -1140,7 +1152,6 @@ class GenericScanKernel(_GenericScanKernelBase): ...@@ -1140,7 +1152,6 @@ class GenericScanKernel(_GenericScanKernelBase):
wg_size = _round_down_to_power_of_2( wg_size = _round_down_to_power_of_2(
min(max_wg_size, 128)) min(max_wg_size, 128))
scan_tpl = _make_template(SCAN_INTERVALS_SOURCE) scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
scan_src = str(scan_tpl.render( scan_src = str(scan_tpl.render(
wg_size=wg_size, wg_size=wg_size,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment