diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py index 2e6caa6109dfde15b788f34a87a1b1f06b1bd200..e5a724ba0bfa88df31f3ed5fd307e7a200cb7852 100644 --- a/pyopencl/cffi_cl.py +++ b/pyopencl/cffi_cl.py @@ -945,16 +945,21 @@ class Buffer(MemoryObject): size, flags)) sub_buf = self._create(_sub_buf[0]) MemoryObject.__init__(sub_buf, None) + return sub_buf def __getitem__(self, idx): - if not (idx.step == 1 or idx.stop is None): - raise RuntimeError("Buffer slice must have stride 1", + if not isinstance(idx, slice): + raise TypeError("buffer subscript must be a slice object") + + start, stop, stride = idx.indices(self.size) + if stride != 1: + raise ValueError("Buffer slice must have stride 1", status_code.INVALID_VALUE, "Buffer.__getitem__") - _ret = _ffi.new('clobj_t*') - _handle_error(_lib.buffer__getitem( - _ret, self.ptr, idx.start or 0, idx.stop or 0)) - ret = self._create(_ret[0]) - MemoryObject.__init__(ret, None) + + assert start <= stop + + size = stop - start + return self.get_sub_region(start, size) # }}} diff --git a/src/c_wrapper/buffer.cpp b/src/c_wrapper/buffer.cpp index bda4499d386a1e04eb3c6379dfbb1a947ee5496c..102e64fe19082411e4db6970c97e678bf28e2577 100644 --- a/src/c_wrapper/buffer.cpp +++ b/src/c_wrapper/buffer.cpp @@ -23,33 +23,6 @@ buffer::get_sub_region(size_t orig, size_t size, cl_mem_flags flags) const }); return new_buffer(mem); } - -PYOPENCL_USE_RESULT buffer* -buffer::getitem(ssize_t start, ssize_t end) const -{ - ssize_t length; - pyopencl_call_guarded(clGetMemObjectInfo, this, CL_MEM_SIZE, - size_arg(length), nullptr); - if (PYOPENCL_UNLIKELY(length <= 0)) - throw clerror("Buffer.__getitem__", CL_INVALID_VALUE, - "Cannot get the length of the buffer."); - if (end == 0 || end > length) { - end = length; - } else if (end < 0) { - end += length; - } - if (start < 0) { - start += length; - } - if (end <= start || start < 0) - throw clerror("Buffer.__getitem__", CL_INVALID_VALUE, - "Buffer slice should have end > start >= 0"); - cl_mem_flags flags; - pyopencl_call_guarded(clGetMemObjectInfo, this, CL_MEM_FLAGS, - size_arg(flags), nullptr); - flags &= ~CL_MEM_COPY_HOST_PTR; - return get_sub_region((size_t)start, (size_t)(end - start), flags); -} #endif // c wrapper diff --git a/src/c_wrapper/buffer.h b/src/c_wrapper/buffer.h index 23b990084295cb98b6e785f7531ccfc7f6ae3fcd..c97a7919b56e5fda3bec2e739520f21991cbc544 100644 --- a/src/c_wrapper/buffer.h +++ b/src/c_wrapper/buffer.h @@ -17,7 +17,6 @@ public: #if PYOPENCL_CL_VERSION >= 0x1010 PYOPENCL_USE_RESULT buffer *get_sub_region(size_t orig, size_t size, cl_mem_flags flags) const; - PYOPENCL_USE_RESULT buffer *getitem(ssize_t start, ssize_t end) const; #endif }; diff --git a/test/test_wrapper.py b/test/test_wrapper.py index a741e0873ae8171c1489061a3ad97f63dacf53b9..063071bafea6d69ca9a23f5ee10bc1f8e2320c5b 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -863,6 +863,34 @@ def test_global_offset(ctx_factory): assert (a_2 == 2*a).all() +def test_sub_buffers(ctx_factory): + ctx = ctx_factory() + if (ctx._get_cl_version() < (1, 1) and + cl.get_cl_header_version() < (1, 1)): + from pytest import skip + skip("sub-buffers are only available in OpenCL 1.1") + + alignment = ctx.devices[0].mem_base_addr_align + + queue = cl.CommandQueue(ctx) + + n = 30000 + a = (np.random.rand(n) * 100).astype(np.uint8) + + mf = cl.mem_flags + a_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a) + + start = (5000 // alignment) * alignment + stop = start + 20 * alignment + + a_sub_ref = a[start:stop] + + a_sub = np.empty_like(a_sub_ref) + cl.enqueue_copy(queue, a_sub, a_buf[start:stop]) + + assert np.array_equal(a_sub, a_sub_ref) + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa