From 08507a3353ad7cf61c05a0c18b092e5c730ac66c Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 10 Sep 2018 22:11:52 -0500
Subject: [PATCH] Fix, test cl{Compile,Link}Program wrappers

---
 pyopencl/__init__.py   |  7 +++++--
 src/wrap_cl.hpp        | 32 ++++++++++++++++++++++++++++++--
 src/wrap_constants.cpp |  1 +
 test/test_wrapper.py   | 25 +++++++++++++++++++++++++
 4 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py
index d0904c29..bc8cf16c 100644
--- a/pyopencl/__init__.py
+++ b/pyopencl/__init__.py
@@ -555,7 +555,8 @@ class Program(object):
     def compile(self, options=[], devices=None, headers=[]):
         options_bytes, _ = self._process_build_options(self._context, options)
 
-        return self._prg.compile(options_bytes, devices, headers)
+        self._get_prg().compile(options_bytes, devices, headers)
+        return self
 
     def __eq__(self, other):
         return self._get_prg() == other._get_prg()
@@ -577,7 +578,9 @@ def create_program_with_built_in_kernels(context, devices, kernel_names):
 
 def link_program(context, programs, options=[], devices=None):
     options_bytes, _ = Program._process_build_options(context, options)
-    return Program(_Program.link(context, programs, options_bytes, devices))
+    programs = [prg._get_prg() for prg in programs]
+    raw_prg = _Program.link(context, programs, options_bytes, devices)
+    return Program(raw_prg)
 
 # }}}
 
diff --git a/src/wrap_cl.hpp b/src/wrap_cl.hpp
index ace4bd25..601c3dd1 100644
--- a/src/wrap_cl.hpp
+++ b/src/wrap_cl.hpp
@@ -394,6 +394,8 @@
 
 namespace pyopencl
 {
+  class program;
+
   // {{{ error
   class error : public std::runtime_error
   {
@@ -401,11 +403,30 @@ namespace pyopencl
       std::string m_routine;
       cl_int m_code;
 
+      // This is here because clLinkProgram returns a program
+      // object *just* so that there is somewhere for it to
+      // stuff the linker logs. :/
+      bool m_program_initialized;
+      cl_program m_program;
+
     public:
       error(const char *routine, cl_int c, const char *msg="")
-        : std::runtime_error(msg), m_routine(routine), m_code(c)
+        : std::runtime_error(msg), m_routine(routine), m_code(c),
+        m_program_initialized(false), m_program(nullptr)
       { }
 
+      error(const char *routine, cl_program prg, cl_int c,
+          const char *msg="")
+        : std::runtime_error(msg), m_routine(routine), m_code(c),
+        m_program_initialized(true), m_program(prg)
+      { }
+
+      virtual ~error()
+      {
+        if (m_program_initialized)
+          clReleaseProgram(m_program);
+      }
+
       const std::string &routine() const
       {
         return m_routine;
@@ -423,6 +444,8 @@ namespace pyopencl
             || code() == CL_OUT_OF_HOST_MEMORY);
       }
 
+      program *get_program() const;
+
   };
 
   // }}}
@@ -4075,7 +4098,7 @@ namespace pyopencl
         &status_code);
 
     if (status_code != CL_SUCCESS)
-      throw pyopencl::error("clLinkPorgram", status_code);
+      throw pyopencl::error("clLinkProgram", result, status_code);
 
     try
     {
@@ -4736,6 +4759,11 @@ namespace pyopencl
 
   // {{{ deferred implementation bits
 
+  inline program *error::get_program() const
+  {
+    return new program(m_program, /* retain */ true);
+  }
+
   inline py::object create_mem_object_wrapper(cl_mem mem, bool retain=true)
   {
     cl_mem_object_type mem_obj_type;
diff --git a/src/wrap_constants.cpp b/src/wrap_constants.cpp
index 7b6a97f1..597e49c2 100644
--- a/src/wrap_constants.cpp
+++ b/src/wrap_constants.cpp
@@ -108,6 +108,7 @@ void pyopencl_expose_constants(py::module &m)
       .DEF_SIMPLE_METHOD(code)
       .DEF_SIMPLE_METHOD(what)
       .DEF_SIMPLE_METHOD(is_out_of_memory)
+      .def("_program", &cls::get_program)
       ;
   }
 
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 4d729642..338e5fc5 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -1042,6 +1042,31 @@ def test_map_dtype(ctx_factory, dtype):
         assert array.dtype == dt
 
 
+def test_compile_link(ctx_factory):
+    ctx = ctx_factory()
+
+    if ctx._get_cl_version() < (1, 2) or cl.get_cl_header_version() < (1, 2):
+        pytest.skip("Context and ICD loader must understand CL1.2 for compile/link")
+
+    queue = cl.CommandQueue(ctx)
+    vsink_prg = cl.Program(ctx, """//CL//
+        void value_sink(float x)
+        {
+        }
+        """).compile()
+    main_prg = cl.Program(ctx, """//CL//
+        void value_sink(float x);
+
+        __kernel void experiment()
+        {
+            value_sink(3.1415f + get_global_id(0));
+        }
+        """).compile()
+    z = cl.link_program(ctx, [vsink_prg, main_prg], devices=ctx.devices)
+    z.experiment(queue, (128**2,), (128,))
+    queue.finish()
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
     import pyopencl  # noqa
-- 
GitLab