From 2bb87e0f7d886dfb86523cf08b269cad0c0b79fc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Mon, 22 Aug 2022 22:21:48 -0500
Subject: [PATCH] Fix, test capture_call

---
 pyopencl/capture_call.py |  2 +-
 test/test_wrapper.py     | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/pyopencl/capture_call.py b/pyopencl/capture_call.py
index c0648b21..11535f58 100644
--- a/pyopencl/capture_call.py
+++ b/pyopencl/capture_call.py
@@ -146,7 +146,7 @@ def capture_kernel_call(kernel, output_file, queue, g_size, l_size, *args, **kwa
     for name, val in arg_data:
         cg("%s = (" % name)
         with Indentation(cg):
-            val = str(b64encode(compress(memoryview(val))))
+            val = b64encode(compress(memoryview(val))).decode()
             i = 0
             while i < len(val):
                 cg(repr(val[i:i+line_len]))
diff --git a/test/test_wrapper.py b/test/test_wrapper.py
index 22f27f74..0ec3e134 100644
--- a/test/test_wrapper.py
+++ b/test/test_wrapper.py
@@ -1266,6 +1266,39 @@ def test_command_queue_context_manager(ctx_factory):
         q.flush()
 
 
+def test_capture_call(ctx_factory):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    a_np = np.random.rand(500).astype(np.float32)
+    b_np = np.random.rand(500).astype(np.float32)
+
+    ctx = cl.create_some_context()
+    queue = cl.CommandQueue(ctx)
+
+    mf = cl.mem_flags
+    a_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=a_np)
+    b_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=b_np)
+
+    prg = cl.Program(ctx, """
+    __kernel void sum(
+        __global const float *a_g, __global const float *b_g, __global float *res_g)
+    {
+    int gid = get_global_id(0);
+    res_g[gid] = a_g[gid] + b_g[gid];
+    }
+    """).build()
+
+    res_g = cl.Buffer(ctx, mf.WRITE_ONLY, a_np.nbytes)
+    from io import StringIO
+    sio = StringIO()
+    prg.sum.capture_call(sio, queue, a_np.shape, None, a_g, b_g, res_g)
+
+    compile_dict = {}
+    exec(compile(sio.getvalue(), "captured.py", "exec"), compile_dict)
+    compile_dict["main"]()
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
     import pyopencl  # noqa
-- 
GitLab