From b5fdc5f2677a79063fcaee807e3ea64c90ec248a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 9 Aug 2015 13:28:18 -0500
Subject: [PATCH] Test, document Event.set_callback

---
 doc/runtime.rst      |  8 ++++++++
 examples/demo.py     |  3 +--
 pyopencl/cffi_cl.py  |  4 ++--
 test/test_wrapper.py | 46 ++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/doc/runtime.rst b/doc/runtime.rst
index 52d2a521..eafdeeca 100644
--- a/doc/runtime.rst
+++ b/doc/runtime.rst
@@ -278,6 +278,14 @@ Command Queues and Events
     .. automethod:: from_int_ptr
     .. autoattribute:: int_ptr
 
+    .. method:: set_callback(type, cb)
+
+        Add the callback *cb* with signature ``cb(status)`` to the callback
+        queue for the event status *type* (one of the values of
+        :class:`command_execution_status`, except :attr:`command_execution_status.QUEUED`).
+
+        See the OpenCL specification for restrictions on what *cb* may and may not do.
+
     |comparable|
 
 .. function:: wait_for_events(events)
diff --git a/examples/demo.py b/examples/demo.py
index 8fcdd5b4..b6a16939 100644
--- a/examples/demo.py
+++ b/examples/demo.py
@@ -1,8 +1,7 @@
-from __future__ import absolute_import
-from __future__ import print_function
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from __future__ import absolute_import, print_function
 import numpy as np
 import pyopencl as cl
 
diff --git a/pyopencl/cffi_cl.py b/pyopencl/cffi_cl.py
index 24a1ac5c..f1dcf81d 100644
--- a/pyopencl/cffi_cl.py
+++ b/pyopencl/cffi_cl.py
@@ -1096,9 +1096,9 @@ class Event(_Common):
     def wait(self):
         _handle_error(_lib.event__wait(self.ptr))
 
-    def set_callback(self, _type, cb, *args, **kwargs):
+    def set_callback(self, _type, cb):
         def _func(status):
-            cb(status, *args, **kwargs)
+            cb(status)
         _handle_error(_lib.event__set_callback(self.ptr, _type,
                                                _ffi.new_handle(_func)))
 
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 61c6d45a..23d26aef 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -787,6 +787,52 @@ def test_program_valued_get_info(ctx_factory):
     knl.program.binaries[0]
 
 
+def test_event_set_callback(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    if ctx._get_cl_version() < (1, 1):
+        pytest.skip("OpenCL 1.1 or newer required fro set_callback")
+
+    a_np = np.random.rand(50000).astype(np.float32)
+    b_np = np.random.rand(50000).astype(np.float32)
+
+    got_called = []
+
+    def cb(status):
+        got_called.append(status)
+
+    mf = cl.mem_flags
+    a_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=a_np)
+    b_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=b_np)
+
+    prg = cl.Program(ctx, """
+    __kernel void sum(__global const float *a_g, __global const float *b_g,
+        __global float *res_g) {
+      int gid = get_global_id(0);
+      res_g[gid] = a_g[gid] + b_g[gid];
+    }
+    """).build()
+
+    res_g = cl.Buffer(ctx, mf.WRITE_ONLY, a_np.nbytes)
+
+    uevt = cl.UserEvent(ctx)
+
+    evt = prg.sum(queue, a_np.shape, None, a_g, b_g, res_g, wait_for=[uevt])
+
+    evt.set_callback(cl.command_execution_status.COMPLETE, cb)
+
+    uevt.set_status(cl.command_execution_status.COMPLETE)
+
+    queue.finish()
+
+    # yuck
+    from time import sleep
+    sleep(0.1)
+
+    assert got_called
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
     import pyopencl  # noqa
-- 
GitLab