From ffee8b44e499f66782d13a9a8c7d1f875af2ebdb Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 14 Mar 2011 19:09:26 -0400
Subject: [PATCH] Support vector types.

---
 doc/source/array.rst | 20 ++++++++++++++--
 doc/source/misc.rst  |  1 +
 pyopencl/array.py    | 55 ++++++++++++++++++++++++++++++++++++++++++++
 pyopencl/tools.py    | 12 ++++++++--
 test/test_wrapper.py | 18 +++++++++++++++
 5 files changed, 102 insertions(+), 4 deletions(-)

diff --git a/doc/source/array.rst b/doc/source/array.rst
index 0e435a7c..c6586eb0 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -1,8 +1,24 @@
-The :class:`Array` Class
-========================
+Multi-dimensional arrays on the Compute Device
+==============================================
 
 .. module:: pyopencl.array
 
+Vector Types
+------------
+
+.. class :: vec
+
+    All of OpenCL's supported vector types, such as `float3` and `long4` are
+    available as :mod:`numpy` data types within this class. These
+    :class:`numpy.dtype` instances have field names of `x`, `y`, `z`, and `w`
+    just like their OpenCL counterparts. They will work both for parameter passing
+    to kernels as well as for passing data back and forth between kernels and
+    Python code. For each type, a `make_type` function is also provided (e.g.
+    `make_float3(x,y,z)`).
+
+The :class:`Array` Class
+------------------------
+
 .. class:: DefaultAllocator(context, flags=pyopencl.mem_flags.READ_WRITE)
 
     An alias for :class:`pyopencl.tools.CLAllocator`.
diff --git a/doc/source/misc.rst b/doc/source/misc.rst
index f94b7162..53e8982a 100644
--- a/doc/source/misc.rst
+++ b/doc/source/misc.rst
@@ -93,6 +93,7 @@ Version 2011.1
   :func:`pyopencl.array.arange`.
 * Make construction of :class:`pyopencl.array.Array` more flexible (*cqa* argument.)
 * Add :ref:`memory-pools`.
+* Add vector types, see :class:`pyopencl.array.vec`.
 
 Version 0.92
 ------------
diff --git a/pyopencl/array.py b/pyopencl/array.py
index a88b8cce..71a2206f 100644
--- a/pyopencl/array.py
+++ b/pyopencl/array.py
@@ -36,7 +36,62 @@ import pyopencl as cl
 #from pytools import memoize_method
 
 
+# {{{ vector types
+
+class vec:
+    pass
+
+def _create_vector_types():
+    field_names = ["x", "y", "z", "w"]
+
+    name_to_dtype = {}
+    dtype_to_name = {}
+
+    counts = [2, 3, 4, 8, 16]
+    for base_name, base_type in [
+        ('char', np.int8),
+        ('uchar', np.uint8),
+        ('short', np.int16),
+        ('ushort', np.uint16),
+        ('int', np.uint32),
+        ('uint', np.uint32),
+        ('long', np.int64),
+        ('ulong', np.uint64),
+        ('float', np.float32),
+        ('double', np.float64),
+        ]:
+        for count in counts:
+            name = "%s%d" % (base_name, count)
+
+            titles = field_names[:count]
+            if len(titles) < count:
+                titles.extend((count-len(titles))*[None])
+
+            dtype = np.dtype(dict(
+                names=["s%d" % i for i in range(count)],
+                formats=[base_type]*count,
+                titles=titles))
+
+            name_to_dtype[name] = dtype
+            dtype_to_name[dtype] = name
+
+            setattr(vec, name, dtype)
+
+            my_field_names = ",".join(field_names[:count])
+            my_field_names_defaulted = ",".join(
+                    "%s=0" % fn for fn in field_names[:count])
+            setattr(vec, "make_"+name, 
+                    staticmethod(eval(
+                        "lambda %s: array((%s), dtype=my_dtype)"
+                        % (my_field_names_defaulted, my_field_names),
+                        dict(array=np.array, my_dtype=dtype))))
+
+    vec._dtype_to_c_name = dtype_to_name
+    vec._c_name_to_dtype = name_to_dtype
+
+_create_vector_types()
 
+# }}}
 
 # {{{ helper functionality
 
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index eaf1ac11..76470a61 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -32,6 +32,7 @@ OTHER DEALINGS IN THE SOFTWARE.
 import numpy as np
 from decorator import decorator
 import pyopencl as cl
+import pyopencl.array as cl_array
 
 
 
@@ -137,7 +138,10 @@ def dtype_to_ctype(dtype):
     elif dtype == np.complex128:
         return "complex double"
     else:
-        raise ValueError, "unable to map dtype '%s'" % dtype
+        try:
+            return cl_array.vec._dtype_to_c_name[dtype]
+        except KeyError:
+            raise ValueError, "unable to map dtype '%s'" % dtype
 
 # }}}
 
@@ -231,7 +235,11 @@ def parse_c_arg(c_arg):
     elif tp in ["char"]: dtype = np.int8
     elif tp in ["unsigned char"]: dtype = np.uint8
     elif tp in ["bool"]: dtype = np.bool
-    else: raise ValueError, "unknown type '%s'" % tp
+    else:
+        try:
+            return cl_array.vec._c_name_to_dtype[tp]
+        except KeyError:
+            raise ValueError("unknown type '%s'" % tp)
 
     return arg_class(dtype, name, vector_len)
 
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 99c75b8c..3b334cb9 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -17,6 +17,7 @@ def have_cl():
 
 if have_cl():
     import pyopencl as cl
+    import pyopencl.array as cl_array
     from pyopencl.tools import pytest_generate_tests_for_pyopencl \
             as pytest_generate_tests
 
@@ -324,9 +325,26 @@ class TestCL:
             assert MemoryPool.bin_number(asize) == bin_nr, s
             assert asize < asize*(1+1/8)
 
+    @pytools.test.mark_test.opencl
+    def test_vector_args(self, ctx_getter):
+        context = ctx_getter()
+        queue = cl.CommandQueue(context)
+
+        prg = cl.Program(context, """
+            __kernel void set_vec(float4 x, __global float4 *dest)
+            { dest[get_global_id(0)] = x; }
+            """).build()
+
+        x = cl_array.vec.make_float4(1,2,3,4)
+        dest = np.empty(50000, cl_array.vec.float4)
+        mf = cl.mem_flags
+        dest_buf = cl.Buffer(context, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=dest)
 
+        prg.set_vec(queue, dest.shape, None, x, dest_buf)
 
+        cl.enqueue_read_buffer(queue, dest_buf, dest).wait()
 
+        assert (dest == x).all()
 
 
 
-- 
GitLab