diff --git a/pyopencl/bitonic_sort.py b/pyopencl/bitonic_sort.py index 8e2b4045f87769629dde18ebfe41c2e9dc4d645d..d4beaba2c7aee15c4cf82ba5fc2f07daed8bcebe 100644 --- a/pyopencl/bitonic_sort.py +++ b/pyopencl/bitonic_sort.py @@ -121,8 +121,8 @@ class BitonicSort(object): if aux: last_evt = knl( queue, (nt,), wg, arr.data, idx.data, - cl.LocalMemory(wg[0]*4*arr.dtype.itemsize), - cl.LocalMemory(wg[0]*4*idx.dtype.itemsize), + cl.LocalMemory(wg[0]*arr.dtype.itemsize), + cl.LocalMemory(wg[0]*idx.dtype.itemsize), wait_for=[last_evt]) for knl, nt, wg, _ in run_queue[1:]: last_evt = knl( @@ -184,9 +184,9 @@ class BitonicSort(object): available_lmem = dev.local_mem_size while True: - lmem_size = wg*4*key_dtype.itemsize + lmem_size = wg*key_dtype.itemsize if argsort: - lmem_size += wg*4*idx_dtype.itemsize + lmem_size += wg*idx_dtype.itemsize if lmem_size + 512 > available_lmem: wg //= 2