From 9a832d72b968679563834722c6f6db1cab963497 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992@gmail.com>
Date: Sun, 22 Jun 2014 22:35:05 +0800
Subject: [PATCH] _Program.compile, _Program.link, _Program.all_kernels

---
 TODOs                             |  3 --
 pyopencl/c_wrapper/wrap_cl_core.h |  7 +++
 pyopencl/cffi_cl.py               | 24 +++++++++++
 src/c_wrapper/program.cpp         | 71 ++++++++++++++++++++++++++++---
 src/c_wrapper/program.h           | 44 +++----------------
 src/c_wrapper/utils.h             | 10 ++---
 6 files changed, 106 insertions(+), 53 deletions(-)

diff --git a/TODOs b/TODOs
index 9c826a55..dec695e7 100644
--- a/TODOs
+++ b/TODOs
@@ -5,9 +5,6 @@
 - enqueue_nd_range_kernel size/offset mess
 
 - CommandQueue.set_property
-- _Program.compile
-- _Program.link
-- _Program.all_kernels
 - ?LocalMemory
 - get_apple_cgl_share_group
 - GLBuffer
diff --git a/pyopencl/c_wrapper/wrap_cl_core.h b/pyopencl/c_wrapper/wrap_cl_core.h
index 396014b9..20fa204d 100644
--- a/pyopencl/c_wrapper/wrap_cl_core.h
+++ b/pyopencl/c_wrapper/wrap_cl_core.h
@@ -113,6 +113,13 @@ error *program__create_with_builtin_kernels(clobj_t *_prg, clobj_t _ctx,
                                             const clobj_t *_devs,
                                             uint32_t num_devs,
                                             const char *names);
+error *program__compile(clobj_t _prg, const char *opts, const clobj_t *_devs,
+                        size_t num_devs, const clobj_t *_prgs,
+                        const char *const *names, size_t num_hdrs);
+error *program__link(clobj_t *_prg, clobj_t _ctx, const clobj_t *_prgs,
+                     size_t num_prgs, const char *opts,
+                     const clobj_t *_devs, size_t num_devs);
+error *program__all_kernels(clobj_t _prg, clobj_t **_knl, uint32_t *size);
 // Sampler
 error *create_sampler(clobj_t *sampler, clobj_t context, int norm_coords,
                       cl_addressing_mode am, cl_filter_mode fm);
diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py
index 422fad54..8438355c 100644
--- a/pyopencl/cffi_cl.py
+++ b/pyopencl/cffi_cl.py
@@ -881,6 +881,24 @@ class _Program(_Common):
             self.ptr, device.ptr, param, info))
         return _generic_info_to_python(info)
 
+    def compile(self, options="", devices=None, headers=()):
+        """Compile; headers is an iterable of (include_name, _Program)."""
+        _devs, num_devs = _clobj_list(devices)
+        hdrs = [(prg.ptr, _to_cstring(name)) for (name, prg) in headers]
+        _prgs, names = zip(*hdrs) if hdrs else (_ffi.NULL, _ffi.NULL)
+        _handle_error(_lib.program__compile(self.ptr, _to_cstring(options),
+            _devs, num_devs, _prgs, names, len(hdrs)))
+
+    @classmethod
+    def link(cls, context, programs, options="", devices=None):
+        """Link `programs` in `context`; returns the new program object."""
+        devs, n_devs = _clobj_list(devices)
+        prgs, n_prgs = _clobj_list(programs)
+        out = _ffi.new('clobj_t*')
+        _handle_error(_lib.program__link(out, context.ptr, prgs, n_prgs,
+                                         _to_cstring(options), devs, n_devs))
+        return cls._create(out[0])
+
     @classmethod
     def create_with_builtin_kernels(cls, context, devices, kernel_names):
         _devs, num_devs = _clobj_list(devices)
@@ -889,6 +907,12 @@ class _Program(_Common):
             _prg, context.ptr, _devs, num_devs, _to_cstring(kernel_names)))
         return cls._create(_prg[0])
 
+    def all_kernels(self):
+        """Return a list of Kernel objects, one per kernel in the program."""
+        knls = _CArray(_ffi.new('clobj_t**'))
+        _handle_error(_lib.program__all_kernels(self.ptr, knls.ptr, knls.size))
+        return [Kernel._create(knls.ptr[0][i]) for i in xrange(knls.size[0])]
+
 # }}}
 
 
diff --git a/src/c_wrapper/program.cpp b/src/c_wrapper/program.cpp
index 286649cc..39f30b29 100644
--- a/src/c_wrapper/program.cpp
+++ b/src/c_wrapper/program.cpp
@@ -2,6 +2,7 @@
 #include "device.h"
 #include "context.h"
 #include "clhelper.h"
+#include "kernel.h"
 
 namespace pyopencl {
 
@@ -95,18 +96,37 @@ program::get_build_info(const device *dev, cl_program_build_info param) const
     }
 }
 
+#if PYOPENCL_CL_VERSION >= 0x1020
+// Runs clCompileProgram on this program; _prgs/names hold num_hdrs headers.
+void program::compile(const char *opts, const clobj_t *_devs, size_t num_devs,
+                      const clobj_t *_prgs, const char *const *names,
+                      size_t num_hdrs)
+{
+    const auto dev_buf = buf_from_class<device>(_devs, num_devs);
+    const auto hdr_buf = buf_from_class<program>(_prgs, num_hdrs);
+    pyopencl_call_guarded(clCompileProgram, this, dev_buf, opts, hdr_buf,
+                          buf_arg(names, num_hdrs), nullptr, nullptr);
+}
+#endif
+
+// Queries the kernel count, then creates all kernels of this program.
+pyopencl_buf<clobj_t> program::all_kernels()
+{
+    cl_uint count;
+    pyopencl_call_guarded(clCreateKernelsInProgram, this, 0, nullptr,
+                          buf_arg(count));
+    pyopencl_buf<cl_kernel> raw_knls(count);
+    pyopencl_call_guarded(clCreateKernelsInProgram, this, raw_knls,
+                          buf_arg(count));
+    return buf_to_base<kernel>(raw_knls, true);
+}
+
 }
 
 // c wrapper
 // Import all the names in pyopencl namespace for c wrappers.
 using namespace pyopencl;
 
-typedef cl_program (*_clCreateProgramWithSourceType)(
-    cl_context, cl_uint, const char* const*, const size_t*, cl_int*);
-
-const static _clCreateProgramWithSourceType _clCreateProgramWithSource =
-    reinterpret_cast<_clCreateProgramWithSourceType>(clCreateProgramWithSource);
-
 // Program
 error*
 create_program_with_source(clobj_t *prog, clobj_t _ctx, const char *_src)
@@ -116,7 +136,7 @@ create_program_with_source(clobj_t *prog, clobj_t _ctx, const char *_src)
             const auto &src = _src;
             const size_t length = strlen(src);
             cl_program result = pyopencl_call_guarded(
-                _clCreateProgramWithSource, ctx, len_arg(src), buf_arg(length));
+                clCreateProgramWithSource, ctx, len_arg(src), buf_arg(length));
             *prog = new_program(result, KND_SOURCE);
         });
 }
@@ -184,4 +204,41 @@ program__create_with_builtin_kernels(clobj_t *_prg, clobj_t _ctx,
             *_prg = new_program(prg);
         });
 }
+
+// C wrapper: forwards to program::compile() inside c_handle_error.
+error *program__compile(clobj_t _prg, const char *opts, const clobj_t *_devs,
+                        size_t num_devs, const clobj_t *_prgs,
+                        const char *const *names, size_t num_hdrs)
+{
+    return c_handle_error([&] {
+            static_cast<program*>(_prg)->compile(opts, _devs, num_devs, _prgs,
+                                                 names, num_hdrs);
+        });
+}
+
+error *program__link(clobj_t *_prg, clobj_t _ctx, const clobj_t *_prgs,
+                     size_t num_prgs, const char *opts, const clobj_t *_devs,
+                     size_t num_devs)
+{
+    auto ctx = static_cast<context*>(_ctx);
+    return c_handle_error([&] {
+            // Convert handles inside the guarded lambda, like program__compile.
+            const auto devs = buf_from_class<device>(_devs, num_devs);
+            const auto prgs = buf_from_class<program>(_prgs, num_prgs);
+            auto prg = pyopencl_call_guarded(clLinkProgram, ctx, devs, opts,
+                                             prgs, nullptr, nullptr);
+            *_prg = new_program(prg);
+        });
+}
 #endif
+
+// C wrapper for program::all_kernels(); stores the released array in *_knl.
+error *program__all_kernels(clobj_t _prg, clobj_t **_knl, uint32_t *size)
+{
+    const auto self = static_cast<program*>(_prg);
+    return c_handle_error([&] {
+            auto kernels = self->all_kernels();
+            *size = kernels.len();
+            *_knl = kernels.release();
+        });
+}
diff --git a/src/c_wrapper/program.h b/src/c_wrapper/program.h
index 8e148fff..e0b6f640 100644
--- a/src/c_wrapper/program.h
+++ b/src/c_wrapper/program.h
@@ -45,44 +45,12 @@ public:
     generic_info get_info(cl_uint param_name) const;
     PYOPENCL_USE_RESULT generic_info
     get_build_info(const device *dev, cl_program_build_info param_name) const;
-
-    // #if PYOPENCL_CL_VERSION >= 0x1020
-    //       void compile(std::string options, py::object py_devices,
-    //           py::object py_headers)
-    //       {
-    //         PYOPENCL_PARSE_PY_DEVICES;
-
-    //         // {{{ pick apart py_headers
-    //         // py_headers is a list of tuples *(name, program)*
-
-    //         std::vector<std::string> header_names;
-    //         std::vector<cl_program> programs;
-    //         PYTHON_FOREACH(name_hdr_tup, py_headers)
-    //         {
-    //           if (py::len(name_hdr_tup) != 2)
-    //             throw error("Program.compile", CL_INVALID_VALUE,
-    //                 "epxected (name, header) tuple in headers list");
-    //           std::string name = py::extract<std::string const &>(name_hdr_tup[0]);
-    //           program &prg = py::extract<program &>(name_hdr_tup[1]);
-
-    //           header_names.push_back(name);
-    //           programs.push_back(prg.data());
-    //         }
-
-    //         std::vector<const char *> header_name_ptrs;
-    //         BOOST_FOREACH(std::string const &name, header_names)
-    //           header_name_ptrs.push_back(name.c_str());
-
-    //         // }}}
-
-    //         PYOPENCL_CALL_GUARDED(clCompileProgram,
-    //             (this, num_devices, devices,
-    //              options.c_str(), header_names.size(),
-    //              programs.empty() ? nullptr : &programs.front(),
-    //              header_name_ptrs.empty() ? nullptr : &header_name_ptrs.front(),
-    //              0, 0));
-    //       }
-    // #endif
+#if PYOPENCL_CL_VERSION >= 0x1020
+    void compile(const char *opts, const clobj_t *_devs, size_t num_devs,
+                 const clobj_t *_prgs, const char *const *names,
+                 size_t num_hdrs);
+#endif
+    pyopencl_buf<clobj_t> all_kernels();
 };
 
 extern template void print_clobj<program>(std::ostream&, const program*);
diff --git a/src/c_wrapper/utils.h b/src/c_wrapper/utils.h
index 32c7beaf..ea7ebafd 100644
--- a/src/c_wrapper/utils.h
+++ b/src/c_wrapper/utils.h
@@ -226,15 +226,15 @@ public:
     ArgBuffer(ArgBuffer<T, AT> &&other) noexcept
         : ArgBuffer(other.m_buf, other.m_len)
     {}
-    PYOPENCL_INLINE T*
+    PYOPENCL_INLINE rm_const_t<T>*
     get() const noexcept
     {
-        return m_buf;
+        return const_cast<rm_const_t<T>*>(m_buf);
     }
     PYOPENCL_INLINE T&
     operator[](int i) const
     {
-        return this->get()[i];
+        return m_buf[i];
     }
     PYOPENCL_INLINE size_t
     len() const noexcept
@@ -288,8 +288,8 @@ struct _ArgBufferConverter;
 template<typename Buff>
 struct _ArgBufferConverter<Buff,
                            enable_if_t<Buff::arg_type == ArgType::None> > {
-    static PYOPENCL_INLINE typename Buff::type*
-    convert(Buff &buff)
+    static PYOPENCL_INLINE auto
+    convert(Buff &buff) -> decltype(buff.get())
     {
         return buff.get();
     }
-- 
GitLab