diff --git a/pyopencl/array.py b/pyopencl/array.py index 474a6bc70556668d67169188b16c710da82291da..c0c74ef6367223f1351f6d761605319090880137 100644 --- a/pyopencl/array.py +++ b/pyopencl/array.py @@ -37,8 +37,9 @@ import pyopencl as cl def splay(ctx, n): - min_work_items = 32 - max_work_items = 128 + max_work_items = max(dev.max_work_group_size for dev in ctx.devices) + max_work_items = min(128, max_work_items) + min_work_items = min(32, max_work_items) max_groups = max( 4 * dev.max_compute_units * 8 # 4 to overfill the device