diff --git a/pyopencl/array.py b/pyopencl/array.py
index 474a6bc70556668d67169188b16c710da82291da..c0c74ef6367223f1351f6d761605319090880137 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -37,8 +37,9 @@ import pyopencl as cl
 
 
 def splay(ctx, n):
-    min_work_items = 32
-    max_work_items = 128
+    max_work_items = max(dev.max_work_group_size for dev in ctx.devices)
+    max_work_items = min(128, max_work_items)
+    min_work_items = min(32, max_work_items)
     max_groups = max(
             4 * dev.max_compute_units * 8
             # 4 to overfill the device