Skip to content
Snippets Groups Projects
Commit ed655e92 authored by Marko Bencun's avatar Marko Bencun
Browse files

experimenting with the buffer protocol

parent a3729f15
No related branches found
No related tags found
No related merge requests found
#from pyopencl._cl import PooledBuffer, MemoryPool from pyopencl._cl import PooledBuffer, MemoryPool
from _cffi import _ffi, _lib from _cffi import _ffi, _lib
import warnings import warnings
import np
import ctypes
bitlog2 = _lib.bitlog2 bitlog2 = _lib.bitlog2
...@@ -53,7 +55,7 @@ for type_, d in _constants.iteritems(): ...@@ -53,7 +55,7 @@ for type_, d in _constants.iteritems():
# {{{ exceptions # {{{ exceptions
class Error(Exception): class Error(Exception):
def __init__(self, routine, code, msg=""): def __init__(self, msg='', routine='', code=0):
self.routine = routine self.routine = routine
self.code = code self.code = code
self.what = msg self.what = msg
...@@ -84,7 +86,7 @@ def _handle_error(error): ...@@ -84,7 +86,7 @@ def _handle_error(error):
klass = RuntimeError klass = RuntimeError
else: else:
klass = Error klass = Error
e = klass(_ffi.string(error.routine), error.code, _ffi.string(error.msg)) e = klass(routine=_ffi.string(error.routine), code=error.code, msg=_ffi.string(error.msg))
_lib._free(error.routine) _lib._free(error.routine)
_lib._free(error.msg) _lib._free(error.msg)
_lib._free(error) _lib._free(error)
...@@ -203,11 +205,30 @@ class MemoryObject(MemoryObjectHolder): ...@@ -203,11 +205,30 @@ class MemoryObject(MemoryObjectHolder):
pass pass
def _c_buffer_from_obj(obj): def _c_buffer_from_obj(obj, writable=False):
if obj is None: if obj is None:
return _ffi.NULL, 0 return _ffi.NULL, 0
# assume numpy array for now
return _ffi.cast('void *', obj.__array_interface__['data'][0]), obj.nbytes # numpy array
if hasattr(obj, '__array_interface__'): # if writeable == False, some tests fail...
return _ffi.cast('void *', obj.__array_interface__['data'][0]), obj.nbytes
# ... ?
# fall back to old CPython buffer protocol API
addr = ctypes.c_void_p()
length = ctypes.c_ssize_t()
try:
if writable:
status = ctypes.pythonapi.PyObject_AsWriteBuffer(ctypes.py_object(obj), ctypes.byref(addr), ctypes.byref(length))
else:
status = ctypes.pythonapi.PyObject_AsReadBuffer(ctypes.py_object(obj), ctypes.byref(addr), ctypes.byref(length))
except TypeError:
raise LogicError("", status_code.INVALID_VALUE, "PyOpencl does not accept bare Python types as arguments")
else:
if status:
raise Exception('TODO error_already_set')
return _ffi.cast('void *', addr.value), length.value
class Buffer(MemoryObject): class Buffer(MemoryObject):
_id = 'buffer' _id = 'buffer'
...@@ -216,7 +237,7 @@ class Buffer(MemoryObject): ...@@ -216,7 +237,7 @@ class Buffer(MemoryObject):
warnings.warn("'hostbuf' was passed, but no memory flags to make use of it.") warnings.warn("'hostbuf' was passed, but no memory flags to make use of it.")
c_hostbuf = _ffi.NULL c_hostbuf = _ffi.NULL
if hostbuf is not None: if hostbuf is not None:
c_hostbuf, hostbuf_size = _c_buffer_from_obj(hostbuf) c_hostbuf, hostbuf_size = _c_buffer_from_obj(hostbuf, writable=flags & mem_flags.USE_HOST_PTR)
if size > hostbuf_size: if size > hostbuf_size:
raise RuntimeError("Buffer", status_code.INVALID_VALUE, "specified size is greater than host buffer size") raise RuntimeError("Buffer", status_code.INVALID_VALUE, "specified size is greater than host buffer size")
if size == 0: if size == 0:
...@@ -446,7 +467,7 @@ def _c_obj_list(objs=None): ...@@ -446,7 +467,7 @@ def _c_obj_list(objs=None):
return _ffi.new('void *[]', [ev.ptr for ev in objs]), len(objs) return _ffi.new('void *[]', [ev.ptr for ev in objs]), len(objs)
def _enqueue_read_buffer(queue, mem, hostbuf, device_offset=0, wait_for=None, is_blocking=True): def _enqueue_read_buffer(queue, mem, hostbuf, device_offset=0, wait_for=None, is_blocking=True):
c_buf, size = _c_buffer_from_obj(hostbuf) c_buf, size = _c_buffer_from_obj(hostbuf, writable=True)
ptr_event = _ffi.new('void **') ptr_event = _ffi.new('void **')
c_wait_for, num_wait_for = _c_obj_list(wait_for) c_wait_for, num_wait_for = _c_obj_list(wait_for)
_handle_error(_lib._enqueue_read_buffer( _handle_error(_lib._enqueue_read_buffer(
...@@ -493,7 +514,7 @@ def _enqueue_write_buffer(queue, mem, hostbuf, device_offset=0, wait_for=None, i ...@@ -493,7 +514,7 @@ def _enqueue_write_buffer(queue, mem, hostbuf, device_offset=0, wait_for=None, i
return _create_instance(Event, ptr_event[0]) return _create_instance(Event, ptr_event[0])
def _enqueue_read_image(queue, mem, origin, region, hostbuf, row_pitch=0, slice_pitch=0, wait_for=None, is_blocking=True): def _enqueue_read_image(queue, mem, origin, region, hostbuf, row_pitch=0, slice_pitch=0, wait_for=None, is_blocking=True):
c_buf, size = _c_buffer_from_obj(hostbuf) c_buf, size = _c_buffer_from_obj(hostbuf, writable=True)
ptr_event = _ffi.new('void **') ptr_event = _ffi.new('void **')
c_wait_for, num_wait_for = _c_obj_list(wait_for) c_wait_for, num_wait_for = _c_obj_list(wait_for)
_handle_error(_lib._enqueue_read_image( _handle_error(_lib._enqueue_read_image(
...@@ -654,7 +675,7 @@ class Image(MemoryObject): ...@@ -654,7 +675,7 @@ class Image(MemoryObject):
if shape is None: if shape is None:
raise LogicError("Image", status_code.INVALID_VALUE, "'shape' must be given") raise LogicError("Image", status_code.INVALID_VALUE, "'shape' must be given")
c_buf, size = _c_buffer_from_obj(buffer) c_buf, size = _c_buffer_from_obj(buffer, writable=flags & mem_flags.USE_HOST_PTR)
dims = len(shape) dims = len(shape)
if dims == 2: if dims == 2:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment