diff --git a/pyopencl/tools.py b/pyopencl/tools.py index a0865b62d601ac20387123b2a61d7071793ac4f7..633854e07a523b2cac694b111fabb4b7422ad73c 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -249,11 +249,18 @@ def pytest_generate_tests_for_pyopencl(metafunc): (self.device.name.strip(), self.device.platform.name.strip())) + class QueueFactory: + def __init__(self, ctx_factory): + self.ctx_factory = ctx_factory + + def __call__(self): + return cl.CommandQueue(self.ctx_factory()) + test_plat_and_dev = get_test_platforms_and_devices() arg_names = [] - for arg in ("platform", "device", "cl_queue", "ctx_factory", "ctx_getter"): + for arg in ("platform", "device", "ctx_factory", "cl_queue", "ctx_getter"): if arg not in metafunc.fixturenames: continue @@ -277,11 +284,9 @@ def pytest_generate_tests_for_pyopencl(metafunc): for device in plat_devs: arg_dict["device"] = device arg_dict["ctx_factory"] = ContextFactory(device) + arg_dict["cl_queue"] = QueueFactory(ContextFactory(device)) arg_dict["ctx_getter"] = ContextFactory(device) - context = arg_dict["ctx_factory"]() - arg_dict["cl_queue"] = cl.CommandQueue(context) - arg_values.append(tuple(arg_dict[name] for name in arg_names)) if arg_names: diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 280f414dd30db768fcbbf2df26834f5bd4eaf062..74376b35fefedd5d6a795c6436347a843180e81b 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -534,7 +534,7 @@ def test_copy_buffer(ctx_factory): def test_copy_buffer_using_cl_queue(ctx_factory, cl_queue): context = ctx_factory() - queue = cl_queue + queue = cl_queue() mf = cl.mem_flags a = np.random.rand(50000).astype(np.float32)