From f3d868a83dd5a39b61da84d82e87b96f6eb06a12 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 5 Nov 2015 00:48:43 -0600 Subject: [PATCH] =?UTF-8?q?Fix=20global=20offset=20in=20kernel=20enqueue?= =?UTF-8?q?=20(reported=20by=20Guilherme=20Gon=C3=A7alves=20Ferrari)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyopencl/cffi_cl.py | 9 +++++++-- test/test_wrapper.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py index 5194f0b1..ecb42397 100644 --- a/pyopencl/cffi_cl.py +++ b/pyopencl/cffi_cl.py @@ -1187,10 +1187,15 @@ def enqueue_nd_range_kernel(queue, kernel, global_work_size, local_work_size, global_work_size[i] * local_work_size[i] for i in range(work_dim)) + c_global_work_offset = _ffi.NULL if global_work_offset is not None: - raise NotImplementedError("global_work_offset") + if work_dim != len(global_work_offset): + raise RuntimeError("global work size and offset have differing " + "dimensions", status_code.INVALID_VALUE, + "enqueue_nd_range_kernel") + + c_global_work_offset = global_work_offset - c_global_work_offset = _ffi.NULL if local_work_size is None: local_work_size = _ffi.NULL diff --git a/test/test_wrapper.py b/test/test_wrapper.py index ab833c38..a741e087 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -837,6 +837,32 @@ def test_event_set_callback(ctx_factory): assert got_called +def test_global_offset(ctx_factory): + context = ctx_factory() + queue = cl.CommandQueue(context) + + prg = cl.Program(context, """ + __kernel void mult(__global float *a) + { a[get_global_id(0)] *= 2; } + """).build() + + n = 50 + a = np.random.rand(n).astype(np.float32) + + queue = cl.CommandQueue(context) + mf = cl.mem_flags + a_buf = cl.Buffer(context, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a) + + step = 10 + for ofs in range(0, n, step): + prg.mult(queue, (step,), None, a_buf, global_offset=(ofs,)) + + a_2 = np.empty_like(a) + cl.enqueue_copy(queue, a_2, a_buf) + + assert (a_2 == 2*a).all() + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa -- GitLab