diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py index ea4d01a9e77b13be2ed9e973d3a13d9f93cc7520..d98f169d0813e5df6b692c8d7e91b3d0882e94b8 100644 --- a/pyopencl/cffi_cl.py +++ b/pyopencl/cffi_cl.py @@ -2187,8 +2187,11 @@ def enqueue_map_buffer(queue, buf, flags, offset, shape, dtype, _handle_error(_lib.enqueue_map_buffer(_event, _map, queue.ptr, buf.ptr, flags, offset, byte_size, c_wait_for, num_wait_for, bool(is_blocking))) - return (np.asarray(MemoryMap._create(_map[0], shape, dtype.str, strides)), - Event._create(_event[0])) + mmap = MemoryMap._create(_map[0], shape, dtype.str, strides) + ary = np.asarray(mmap) + ary.dtype = dtype + + return (ary, Event._create(_event[0])) def _enqueue_fill_buffer(queue, mem, pattern, offset, size, wait_for=None): @@ -2283,8 +2286,10 @@ def enqueue_map_image(queue, img, flags, origin, region, shape, dtype, flags, origin, origin_l, region, region_l, _row_pitch, _slice_pitch, c_wait_for, num_wait_for, is_blocking)) - return (np.asarray(MemoryMap._create(_map[0], shape, dtype.str, strides)), - Event._create(_event[0]), _row_pitch[0], _slice_pitch[0]) + mmap = MemoryMap._create(_map[0], shape, dtype.str, strides) + ary = np.asarray(mmap) + ary.dtype = dtype + return (ary, Event._create(_event[0]), _row_pitch[0], _slice_pitch[0]) def enqueue_fill_image(queue, img, color, origin, region, wait_for=None): diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 9f0eb8641051b84c39c9e50d68f8800c9f7138d9..e7d8c3d9e3faf6f23b8c0e08004a8ce338f681a2 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -1031,6 +1031,26 @@ def test_fine_grain_svm(ctx_factory): assert np.array_equal(orig_ary*2, ary) +@pytest.mark.parametrize("dtype", [ + np.uint, + cl.cltypes.uint2, + ]) +def test_map_dtype(ctx_factory, dtype): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + dt = np.dtype(dtype) + + b = pyopencl.Buffer(ctx, + pyopencl.mem_flags.READ_ONLY, + dt.itemsize) + array, ev = pyopencl.enqueue_map_buffer(queue, b, pyopencl.map_flags.WRITE, 0, + (1,), dt) + with array.base: + print(array.dtype) + assert array.dtype == dt + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa