From 2bb87e0f7d886dfb86523cf08b269cad0c0b79fc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Mon, 22 Aug 2022 22:21:48 -0500 Subject: [PATCH] Fix, test capture_call --- pyopencl/capture_call.py | 2 +- test/test_wrapper.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pyopencl/capture_call.py b/pyopencl/capture_call.py index c0648b21..11535f58 100644 --- a/pyopencl/capture_call.py +++ b/pyopencl/capture_call.py @@ -146,7 +146,7 @@ def capture_kernel_call(kernel, output_file, queue, g_size, l_size, *args, **kwa for name, val in arg_data: cg("%s = (" % name) with Indentation(cg): - val = str(b64encode(compress(memoryview(val)))) + val = b64encode(compress(memoryview(val))).decode() i = 0 while i < len(val): cg(repr(val[i:i+line_len])) diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 22f27f74..0ec3e134 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -1266,6 +1266,39 @@ def test_command_queue_context_manager(ctx_factory): q.flush() +def test_capture_call(ctx_factory): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + a_np = np.random.rand(500).astype(np.float32) + b_np = np.random.rand(500).astype(np.float32) + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + mf = cl.mem_flags + a_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=a_np) + b_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=b_np) + + prg = cl.Program(ctx, """ + __kernel void sum( + __global const float *a_g, __global const float *b_g, __global float *res_g) + { + int gid = get_global_id(0); + res_g[gid] = a_g[gid] + b_g[gid]; + } + """).build() + + res_g = cl.Buffer(ctx, mf.WRITE_ONLY, a_np.nbytes) + from io import StringIO + sio = StringIO() + prg.sum.capture_call(sio, queue, a_np.shape, None, a_g, b_g, res_g) + + compile_dict = {} + exec(compile(sio.getvalue(), "captured.py", "exec"), compile_dict) + compile_dict["main"]() + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa -- GitLab