From 7cdf8aa11b3dbb39672fd813002442d0d6048b35 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 30 Jul 2012 14:48:11 -0400
Subject: [PATCH] Add, document pyopencl.tools.match_dtype_to_c_struct.

---
 doc/source/array.rst |  19 +----
 doc/source/misc.rst  |   3 +-
 pyopencl/tools.py    | 180 +++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 176 insertions(+), 26 deletions(-)

diff --git a/doc/source/array.rst b/doc/source/array.rst
index 560c68ae..40656ed3 100644
--- a/doc/source/array.rst
+++ b/doc/source/array.rst
@@ -53,24 +53,7 @@ about them using this function:
 This function helps with producing C/OpenCL declarations for structured
 :class:`numpy.dtype` instances:
 
-.. function:: dtype_to_c_struct(dtype)
-
-    Return a C structure declaration for *dtype*.
-
-    .. versionadded: 2012.2
-
-This example explains its use::
-
-    >>> import pyopencl as cl
-    >>> import pyopencl.tools
-    >>> import numpy as np
-    >>> t = np.dtype([("id", np.uint32), ("value", np.float32)])
-    >>> cl.tools.register_dtype(t, "id_val")
-    >>> print cl.tools.dtype_to_c_struct(t)
-    typedef struct {
-      unsigned id;
-      float value;
-    } id_val;
+.. autofunction:: match_dtype_to_c_struct
 
 .. currentmodule:: pyopencl.array
 
diff --git a/doc/source/misc.rst b/doc/source/misc.rst
index 0adf14c8..b3c907d5 100644
--- a/doc/source/misc.rst
+++ b/doc/source/misc.rst
@@ -80,7 +80,8 @@ Version 2012.2
     PyOpenCL's `git repository <https://github.com/inducer/pyopencl>`_
 
 * Vastly improved :ref:`custom-scan`.
-* Add :func:`pyopencl.tools.dtype_to_c_struct`.
+* Add :func:`pyopencl.tools.match_dtype_to_c_struct`,
+  for better integration of the CL and :mod:`numpy` type systems.
 * More/improved Bessel functions.
   See `the source <https://github.com/inducer/pyopencl/tree/master/src/cl>`_.
 * Add :envvar:`PYOPENCL_NO_CACHE` environment variable to aid debugging
diff --git a/pyopencl/tools.py b/pyopencl/tools.py
index 6da2d17b..8ebf0cd0 100644
--- a/pyopencl/tools.py
+++ b/pyopencl/tools.py
@@ -32,6 +32,7 @@ OTHER DEALINGS IN THE SOFTWARE.
 import numpy as np
 from decorator import decorator
 import pyopencl as cl
+from pytools import memoize
 
 from pyopencl.compyte.dtypes import (
         register_dtype, _fill_dtype_registry,
@@ -294,21 +295,186 @@ def get_gl_sharing_context_properties():
 
 
 
-def dtype_to_c_struct(dtype):
-    dtype = np.dtype(dtype)
+class _CDeclList:
+    """Accumulate the C declarations needed by a set of dtypes on *device*."""
+    def __init__(self, device):
+        self.device = device
+        self.declared_dtypes = set()
+        self.declarations = []
+        self.saw_double = False
+        self.saw_complex = False
+
+    def add_dtype(self, dtype):
+        """Record *dtype*; struct dtypes also get a declaration (fields first)."""
+        dtype = np.dtype(dtype)
+        # Both double and complex double need the fp64 preamble.
+        if dtype in [np.float64, np.complex128]:
+            self.saw_double = True
+
+        if dtype.kind == "c":
+            self.saw_complex = True
+
+        # Only structured ("V") dtypes not seen before need a declaration.
+        if dtype.kind != "V" or dtype in self.declared_dtypes:
+            return
+
+        for name, (field_dtype, offset) in dtype.fields.iteritems():
+            self.add_dtype(field_dtype)
+
+        _, cdecl = match_dtype_to_c_struct(self.device, dtype_to_ctype(dtype), dtype)
+
+        self.declarations.append(cdecl)
+        self.declared_dtypes.add(dtype)
+
+    def get_declarations(self):
+        """Return all collected declarations, preceded by required preamble."""
+        result = "\n\n".join(self.declarations)
+
+        if self.saw_complex:
+            result = "#include <pyopencl-complex.h>\n\n" + result
+
+        # Prepend last: the define must come before the complex header include.
+        if self.saw_double:
+            result = (
+                    "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n"
+                    "#define PYOPENCL_DEFINE_CDOUBLE\n"
+                    + result)
+
+        return result
+
+
+@memoize
+def match_dtype_to_c_struct(device, name, dtype, context=None):
+    """Return a tuple `(dtype, c_decl)` such that the C struct declaration
+    in `c_decl` and the structure :class:`numpy.dtype` instance `dtype`
+    have the same memory layout.
+
+    Note that *dtype* may be modified from the value that was passed in,
+    for example to insert padding.
+
+    (As a remark on implementation, this routine runs a small kernel on
+    the given *device* to ensure that :mod:`numpy` and C offsets and
+    sizes match.)
+
+    .. versionadded:: 2012.2
+
+    This example explains the use of this function::
+
+        >>> import numpy as np
+        >>> import pyopencl as cl
+        >>> import pyopencl.tools
+        >>> dtype = np.dtype([("id", np.uint32), ("value", np.float32)])
+        >>> dtype, c_decl = pyopencl.tools.match_dtype_to_c_struct(ctx.devices[0], 'id_val', dtype)
+        >>> print c_decl
+        typedef struct {
+          unsigned id;
+          float value;
+        } id_val;
+        >>> print dtype
+        [('id', '<u4'), ('value', '<f4')]
+        >>> cl.tools.register_dtype(dtype, 'id_val')
+
+    As this example shows, it is important to call :func:`register_dtype` on
+    the modified `dtype` returned by this function, not the original one.
+    """
 
     fields = sorted(dtype.fields.iteritems(),
             key=lambda (name, (dtype, offset)): offset)
 
-    # FIXME check that this matches C alignment fulres
     c_fields = []
-    for name, (field_dtype, offset) in fields:
-        c_fields.append("  %s %s;" % (dtype_to_ctype(field_dtype), name))
+    for field_name, (field_dtype, offset) in fields:
+        c_fields.append("  %s %s;" % (dtype_to_ctype(field_dtype), field_name))
 
-    return "typedef struct {\n%s\n} %s;" % (
+    c_decl = "typedef struct {\n%s\n} %s;" % (
             "\n".join(c_fields),
-            dtype_to_ctype(dtype))
+            name)
+
+    cdl = _CDeclList(device)  # gathers declarations the field types depend on
+    for field_name, (field_dtype, offset) in fields:
+        cdl.add_dtype(field_dtype)
+
+    pre_decls = cdl.get_declarations()
+
+    offset_code = "\n".join(
+            "result[%d] = pycl_offsetof(%s, %s);" % (i+1, name, field_name)
+            for i, (field_name, (field_dtype, offset)) in enumerate(fields))
+
+    src = r"""
+        #define pycl_offsetof(st, m) \
+                 ((size_t) ( (char *)&((st *)0)->m - (char *)0 ))
+
+        %(pre_decls)s
+
+        %(my_decl)s
+
+        __kernel void get_size_and_offsets(__global size_t *result)
+        {
+            result[0] = sizeof(%(my_type)s);
+            %(offset_code)s
+        }
+    """ % dict(
+            pre_decls=pre_decls,
+            my_decl=c_decl,
+            my_type=name,
+            offset_code=offset_code)
+
+    if context is None:  # none supplied: create one just for this query
+        context = cl.Context([device])
+
+    queue = cl.CommandQueue(context)
+
+    prg = cl.Program(context, src)
+    knl = prg.build(devices=[device]).get_size_and_offsets
+
+    import pyopencl.array
+    result_buf = cl.array.empty(queue, 1+len(fields), np.uintp)  # [0]: struct size, [1:]: field offsets
+    knl(queue, (1,), None, result_buf.data)
+    size_and_offsets = result_buf.get()  # NOTE(review): assumes device size_t matches host np.uintp -- confirm
+
+    result_buf.data.release()  # drop the CL objects before returning
+    del knl
+    del prg
+    del queue
+    del context
+
+    dtype_arg_dict = dict(
+            names=[field_name for field_name, (field_dtype, offset) in fields],
+            formats=[field_dtype for field_name, (field_dtype, offset) in fields],
+            offsets=[int(x) for x in size_and_offsets[1:]],
+            itemsize=int(size_and_offsets[0]),
+            )
+    dtype = np.dtype(dtype_arg_dict)  # rebuild dtype with device-reported offsets and size
+
+    if dtype.itemsize != size_and_offsets[0]:
+        # "Old" versions of numpy (1.6.x?) silently ignore "itemsize". Boo.
+        dtype_arg_dict["names"].append("_pycl_size_fixer")
+        dtype_arg_dict["formats"].append(np.uint8)
+        dtype_arg_dict["offsets"].append(int(size_and_offsets[0])-1)  # one byte at the last offset forces the size
+        dtype = np.dtype(dtype_arg_dict)
+
+    assert dtype.itemsize == size_and_offsets[0]
+
+    return dtype, c_decl
+
+
+
+
+@memoize
+def dtype_to_c_struct(device, dtype):
+    """Return the C struct declaration for *dtype*, verified on *device*."""
+    struct_name = dtype_to_ctype(dtype)
+    matched, decl = match_dtype_to_c_struct(device, struct_name, dtype)
+
+    def dtypes_match():
+        # Same field count, and every field of *dtype* has an equal entry.
+        if len(dtype.fields) != len(matched.fields):
+            return False
+        return all(matched.fields[name] == val
+                for name, val in dtype.fields.iteritems())
+
+    assert dtypes_match()
 
+    return decl
 
 
 
-- 
GitLab