From ffee8b44e499f66782d13a9a8c7d1f875af2ebdb Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <>
Date: Mon, 14 Mar 2011 19:09:26 -0400
Subject: [PATCH] Support vector types.

 doc/source/array.rst | 20 ++++++++++++++--
 doc/source/misc.rst  |  1 +
 pyopencl/    | 55 ++++++++++++++++++++++++++++++++++++++++++++
 pyopencl/    | 12 ++++++++--
 test/ | 18 +++++++++++++++
 5 files changed, 102 insertions(+), 4 deletions(-)

diff --git a/doc/source/array.rst b/doc/source/array.rst
index 0e435a7c..c6586eb0 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -1,8 +1,24 @@
-The :class:`Array` Class
+Multi-dimensional arrays on the Compute Device
 .. module:: pyopencl.array
+Vector Types
+.. class :: vec
+    All of OpenCL's supported vector types, such as `float3` and `long4`, are
+    available as :mod:`numpy` data types within this class. These
+    :class:`numpy.dtype` instances have field names of `x`, `y`, `z`, and `w`
+    just like their OpenCL counterparts. They work both for passing parameters
+    to kernels and for passing data back and forth between kernels and
+    Python code. For each type, a `make_type` function is also provided (e.g.
+    `make_float3(x,y,z)`).
+The :class:`Array` Class
 .. class:: DefaultAllocator(context, flags=pyopencl.mem_flags.READ_WRITE)
     An alias for :class:``.
diff --git a/doc/source/misc.rst b/doc/source/misc.rst
index f94b7162..53e8982a 100644
--- a/doc/source/misc.rst
+++ b/doc/source/misc.rst
@@ -93,6 +93,7 @@ Version 2011.1
 * Make construction of :class:`pyopencl.array.Array` more flexible (*cqa* argument.)
 * Add :ref:`memory-pools`.
+* Add vector types, see :class:`pyopencl.array.vec`.
 Version 0.92
diff --git a/pyopencl/ b/pyopencl/
index a88b8cce..71a2206f 100644
--- a/pyopencl/
+++ b/pyopencl/
@@ -36,7 +36,62 @@ import pyopencl as cl
 #from pytools import memoize_method
+# {{{ vector types
+class vec:
+    # Empty namespace class: the vector dtypes, their ``make_*`` factories and
+    # the name<->dtype lookup tables are attached to it via setattr() in
+    # _create_vector_types().
+    pass
+def _create_vector_types():
+    """Attach numpy dtypes for OpenCL's vector types to :class:`vec`.
+
+    For every combination of base type and component count (e.g. ``float4``)
+    this sets, on :class:`vec`:
+
+    * ``vec.<name>``: a structured :class:`numpy.dtype` whose fields are named
+      ``s0``, ``s1``, ... and whose first (up to) four fields additionally
+      carry the titles ``x``, ``y``, ``z``, ``w``.
+    * ``vec.make_<name>``: a static factory building a value of that dtype
+      from its ``x``/``y``/``z``/``w`` arguments, each defaulting to 0.
+
+    It also stores the lookup tables ``vec._dtype_to_c_name`` and
+    ``vec._c_name_to_dtype`` (dtype <-> C type name).
+    """
+    field_names = ["x", "y", "z", "w"]
+
+    name_to_dtype = {}
+    dtype_to_name = {}
+
+    counts = [2, 3, 4, 8, 16]
+    for base_name, base_type in [
+        ('char', np.int8),
+        ('uchar', np.uint8),
+        ('short', np.int16),
+        ('ushort', np.uint16),
+        # 'int' is the *signed* 32-bit type; mapping it to uint32 would make
+        # intN and uintN the same dtype and clobber the dtype->name table.
+        ('int', np.int32),
+        ('uint', np.uint32),
+        ('long', np.int64),
+        ('ulong', np.uint64),
+        ('float', np.float32),
+        ('double', np.float64),
+        ]:
+        for count in counts:
+            name = "%s%d" % (base_name, count)
+
+            # Only the first four components have letter aliases; pad the
+            # title list with None so it matches the number of fields.
+            titles = field_names[:count]
+            if len(titles) < count:
+                titles.extend((count-len(titles))*[None])
+
+            dtype = np.dtype(dict(
+                names=["s%d" % i for i in range(count)],
+                formats=[base_type]*count,
+                titles=titles))
+
+            name_to_dtype[name] = dtype
+            dtype_to_name[dtype] = name
+
+            setattr(vec, name, dtype)
+
+            # NOTE(review): for counts above 4 the factory still only accepts
+            # x/y/z/w -- confirm whether 8- and 16-wide factories are meant to
+            # be usable.
+            my_field_names = ",".join(field_names[:count])
+            my_field_names_defaulted = ",".join(
+                    "%s=0" % fn for fn in field_names[:count])
+            # eval() is only ever fed the fixed field names above, never
+            # external input; it is used to give the factory a real signature.
+            setattr(vec, "make_"+name,
+                    staticmethod(eval(
+                        "lambda %s: array((%s), dtype=my_dtype)"
+                        % (my_field_names_defaulted, my_field_names),
+                        dict(array=np.array, my_dtype=dtype))))
+
+    vec._dtype_to_c_name = dtype_to_name
+    vec._c_name_to_dtype = name_to_dtype
+
+# Populate ``vec`` at import time; without this call none of the vector
+# dtypes, factories or lookup tables exist.
+_create_vector_types()
+# }}}
 # {{{ helper functionality
diff --git a/pyopencl/ b/pyopencl/
index eaf1ac11..76470a61 100644
--- a/pyopencl/
+++ b/pyopencl/
 import numpy as np
 from decorator import decorator
 import pyopencl as cl
+import pyopencl.array as cl_array
@@ -137,7 +138,10 @@ def dtype_to_ctype(dtype):
     elif dtype == np.complex128:
         return "complex double"
-        raise ValueError, "unable to map dtype '%s'" % dtype
+        try:
+            return cl_array.vec._dtype_to_c_name[dtype]
+        except KeyError:
+            raise ValueError, "unable to map dtype '%s'" % dtype
 # }}}
@@ -231,7 +235,11 @@ def parse_c_arg(c_arg):
     elif tp in ["char"]: dtype = np.int8
     elif tp in ["unsigned char"]: dtype = np.uint8
     elif tp in ["bool"]: dtype = np.bool
-    else: raise ValueError, "unknown type '%s'" % tp
+    else:
+        try:
+            # Fall back to the OpenCL vector type table. Assign ``dtype``
+            # (rather than returning it) so the common
+            # ``arg_class(dtype, name, vector_len)`` return still applies.
+            dtype = cl_array.vec._c_name_to_dtype[tp]
+        except KeyError:
+            raise ValueError("unknown type '%s'" % tp)
     return arg_class(dtype, name, vector_len)
diff --git a/test/ b/test/
index 99c75b8c..3b334cb9 100644
--- a/test/
+++ b/test/
@@ -17,6 +17,7 @@ def have_cl():
 if have_cl():
     import pyopencl as cl
+    import pyopencl.array as cl_array
     from import pytest_generate_tests_for_pyopencl \
             as pytest_generate_tests
@@ -324,9 +325,26 @@ class TestCL:
             assert MemoryPool.bin_number(asize) == bin_nr, s
             assert asize < asize*(1+1/8)
+    @pytools.test.mark_test.opencl
+    def test_vector_args(self, ctx_getter):
+        """Pass a ``float4`` by value to a kernel and check it lands in every slot."""
+        ctx = ctx_getter()
+        cmd_queue = cl.CommandQueue(ctx)
+
+        kernel_source = """
+            __kernel void set_vec(float4 x, __global float4 *dest)
+            { dest[get_global_id(0)] = x; }
+            """
+        program = cl.Program(ctx, kernel_source).build()
+
+        vec_value = cl_array.vec.make_float4(1,2,3,4)
+        host_result = np.empty(50000, cl_array.vec.float4)
+
+        buffer_flags = cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR
+        device_result = cl.Buffer(ctx, buffer_flags, hostbuf=host_result)
+
+        # One work item per destination element; each one stores the vector.
+        program.set_vec(cmd_queue, host_result.shape, None, vec_value, device_result)
+
+        read_event = cl.enqueue_read_buffer(cmd_queue, device_result, host_result)
+        read_event.wait()
+
+        assert (host_result == vec_value).all()