From 398fb31dc9ca295b85000ec5adc9dcf743b366f5 Mon Sep 17 00:00:00 2001 From: Yichao Yu <yyc1992@gmail.com> Date: Fri, 20 Jun 2014 21:51:20 +0800 Subject: [PATCH] Buffer.__getitem__ --- pyopencl/c_wrapper/wrap_cl_core.h | 1 + pyopencl/cffi_cl.py | 11 +++++- src/c_wrapper/buffer.cpp | 61 ++++++++++++++++++------------- src/c_wrapper/buffer.h | 1 + 4 files changed, 46 insertions(+), 28 deletions(-) diff --git a/pyopencl/c_wrapper/wrap_cl_core.h b/pyopencl/c_wrapper/wrap_cl_core.h index 9d760492..9292378a 100644 --- a/pyopencl/c_wrapper/wrap_cl_core.h +++ b/pyopencl/c_wrapper/wrap_cl_core.h @@ -78,6 +78,7 @@ error *create_buffer(clobj_t *buffer, clobj_t context, cl_mem_flags flags, size_t size, void *hostbuf); error *buffer__get_sub_region(clobj_t *_sub_buf, clobj_t _buf, size_t orig, size_t size, cl_mem_flags flags); +error *buffer__getitem(clobj_t *_ret, clobj_t _buf, ssize_t start, ssize_t end); // Memory Object error *memory_object__release(clobj_t obj); // Memory Map diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py index d9cf40c2..cf6248e9 100644 --- a/pyopencl/cffi_cl.py +++ b/pyopencl/cffi_cl.py @@ -726,9 +726,16 @@ class Buffer(MemoryObject): size, flags)) sub_buf = self._create(_sub_buf[0]) MemoryObject.__init__(sub_buf, None) - sub_buf._handle_buf_flags(flags) - # TODO __getitem__ ? + def __getitem__(self, idx): + if not (idx.step == 1 or idx.stop is None): + raise RuntimeError("Buffer slice must have stride 1", + status_code.INVALID_VALUE, "Buffer.__getitem__") + _ret = _ffi.new('clobj_t*') + _handle_error(_lib.buffer__getitem( + _ret, self.ptr, idx.start or 0, idx.stop or 0)) + ret = self._create(_ret[0]) + MemoryObject.__init__(ret, None) # }}} diff --git a/src/c_wrapper/buffer.cpp b/src/c_wrapper/buffer.cpp index 36d10bb0..20741862 100644 --- a/src/c_wrapper/buffer.cpp +++ b/src/c_wrapper/buffer.cpp @@ -26,32 +26,32 @@ buffer::get_sub_region(size_t orig, size_t size, cl_mem_flags flags) const return new_buffer(mem); } -// buffer *getitem(py::slice slc) const -// { -// PYOPENCL_BUFFER_SIZE_T start, end, stride, length; - -// size_t my_length; -// PYOPENCL_CALL_GUARDED(clGetMemObjectInfo, -// (this, CL_MEM_SIZE, sizeof(my_length), &my_length, 0)); - -// #if PY_VERSION_HEX >= 0x03020000 -// if (PySlice_GetIndicesEx(slc.ptr(), -// #else -// if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject *>(slc.ptr()), -// #endif -// my_length, &start, &end, &stride, &length) != 0) -// throw py::error_already_set(); - -// if (stride != 1) -// throw clerror("Buffer.__getitem__", CL_INVALID_VALUE, -// "Buffer slice must have stride 1"); - -// cl_mem_flags my_flags; -// PYOPENCL_CALL_GUARDED(clGetMemObjectInfo, -// (this, CL_MEM_FLAGS, sizeof(my_flags), &my_flags, 0)); - -// return get_sub_region(start, end, my_flags); -// } +PYOPENCL_USE_RESULT buffer* +buffer::getitem(ssize_t start, ssize_t end) const +{ + ssize_t length; + pyopencl_call_guarded(clGetMemObjectInfo, this, CL_MEM_SIZE, + size_arg(length), nullptr); + if (PYOPENCL_UNLIKELY(length <= 0)) + throw clerror("Buffer.__getitem__", CL_INVALID_VALUE, + "Cannot get the length of the buffer."); + if (end == 0 || end > length) { + end = length; + } else if (end < 0) { + end += length; + } + if (start < 0) { + start += length; + } + if (end <= start || start < 0) + throw clerror("Buffer.__getitem__", CL_INVALID_VALUE, + "Buffer slice should have end > start >= 0"); + cl_mem_flags flags; + pyopencl_call_guarded(clGetMemObjectInfo, this, CL_MEM_FLAGS, + size_arg(flags), nullptr); + flags &= ~CL_MEM_COPY_HOST_PTR; + return get_sub_region((size_t)start, (size_t)(end - start), flags); +} #endif } @@ -245,4 +245,13 @@ buffer__get_sub_region(clobj_t *_sub_buf, clobj_t _buf, size_t orig, }); } +error* +buffer__getitem(clobj_t *_ret, clobj_t _buf, ssize_t start, ssize_t end) +{ + auto buf = static_cast<buffer*>(_buf); + return c_handle_error([&] { + *_ret = buf->getitem(start, end); + }); +} + #endif diff --git a/src/c_wrapper/buffer.h b/src/c_wrapper/buffer.h index 193c531b..95feffb8 100644 --- a/src/c_wrapper/buffer.h +++ b/src/c_wrapper/buffer.h @@ -19,6 +19,7 @@ public: #if PYOPENCL_CL_VERSION >= 0x1010 PYOPENCL_USE_RESULT buffer *get_sub_region(size_t orig, size_t size, cl_mem_flags flags) const; + PYOPENCL_USE_RESULT buffer *getitem(ssize_t start, ssize_t end) const; #endif }; -- GitLab