From 5e736c0cba5512872670143c73b4d1b815594961 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 14 Sep 2017 00:03:01 -0500
Subject: [PATCH] Fix vector-typed memory maps (#203 on github)

---
 pyopencl/cffi_cl.py  | 13 +++++++++----
 test/test_wrapper.py | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py
index ea4d01a9..d98f169d 100644
--- a/pyopencl/cffi_cl.py
+++ b/pyopencl/cffi_cl.py
@@ -2187,8 +2187,11 @@ def enqueue_map_buffer(queue, buf, flags, offset, shape, dtype,
     _handle_error(_lib.enqueue_map_buffer(_event, _map, queue.ptr, buf.ptr,
                                           flags, offset, byte_size, c_wait_for,
                                           num_wait_for, bool(is_blocking)))
-    return (np.asarray(MemoryMap._create(_map[0], shape, dtype.str, strides)),
-            Event._create(_event[0]))
+    mmap = MemoryMap._create(_map[0], shape, dtype.str, strides)
+    ary = np.asarray(mmap)
+    ary.dtype = dtype
+
+    return (ary, Event._create(_event[0]))
 
 
 def _enqueue_fill_buffer(queue, mem, pattern, offset, size, wait_for=None):
@@ -2283,8 +2286,10 @@ def enqueue_map_image(queue, img, flags, origin, region, shape, dtype,
                                          flags, origin, origin_l, region,
                                          region_l, _row_pitch, _slice_pitch,
                                          c_wait_for, num_wait_for, is_blocking))
-    return (np.asarray(MemoryMap._create(_map[0], shape, dtype.str, strides)),
-            Event._create(_event[0]), _row_pitch[0], _slice_pitch[0])
+    mmap = MemoryMap._create(_map[0], shape, dtype.str, strides)
+    ary = np.asarray(mmap)
+    ary.dtype = dtype
+    return (ary, Event._create(_event[0]), _row_pitch[0], _slice_pitch[0])
 
 
 def enqueue_fill_image(queue, img, color, origin, region, wait_for=None):
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 9f0eb864..e7d8c3d9 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -1031,6 +1031,26 @@ def test_fine_grain_svm(ctx_factory):
     assert np.array_equal(orig_ary*2, ary)
 
 
+@pytest.mark.parametrize("dtype", [
+    np.uint,
+    cl.cltypes.uint2,
+    ])
+def test_map_dtype(ctx_factory, dtype):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    dt = np.dtype(dtype)
+
+    b = pyopencl.Buffer(ctx,
+                        pyopencl.mem_flags.READ_ONLY,
+                        dt.itemsize)
+    array, ev = pyopencl.enqueue_map_buffer(queue, b, pyopencl.map_flags.WRITE, 0,
+                                            (1,), dt)
+    with array.base:
+        print(array.dtype)
+        assert array.dtype == dt
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
     import pyopencl  # noqa
-- 
GitLab