diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 6a72ee9796d8b728d3bc919633a1da27a891c584..7304cfc73dbcf722223e5bb2da1510dd909de6dc 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -249,38 +249,38 @@ def pytest_generate_tests_for_pyopencl(metafunc): test_plat_and_dev = get_test_platforms_and_devices() - if ("device" in metafunc.funcargnames - or "ctx_factory" in metafunc.funcargnames - or "ctx_getter" in metafunc.funcargnames): - arg_dict = {} - - for platform, plat_devs in test_plat_and_dev: - if "platform" in metafunc.funcargnames: - arg_dict["platform"] = platform - - for device in plat_devs: - if "device" in metafunc.funcargnames: - arg_dict["device"] = device - - if "ctx_factory" in metafunc.funcargnames: - arg_dict["ctx_factory"] = ContextFactory(device) - - if "ctx_getter" in metafunc.funcargnames: - from warnings import warn - warn("The 'ctx_getter' arg is deprecated in " - "favor of 'ctx_factory'.", - DeprecationWarning) - arg_dict["ctx_getter"] = ContextFactory(device) - - metafunc.addcall(funcargs=arg_dict.copy(), - id=", ".join("%s=%s" % (arg, value) - for arg, value in six.iteritems(arg_dict))) - - elif "platform" in metafunc.funcargnames: - for platform, plat_devs in test_plat_and_dev: - metafunc.addcall( - funcargs=dict(platform=platform), - id=str(platform)) + arg_names = [] + + for arg in ("platform", "device", "ctx_factory", "ctx_getter"): + if arg not in metafunc.funcargnames: + continue + + if arg == "ctx_getter": + from warnings import warn + warn("The 'ctx_getter' arg is deprecated in " + "favor of 'ctx_factory'.", + DeprecationWarning) + + arg_names.append(arg) + + arg_values = [] + + for platform, plat_devs in test_plat_and_dev: + if arg_names == ["platform"]: + arg_values.append((platform,)) + continue + + arg_dict = {"platform": platform} + + for device in plat_devs: + arg_dict["device"] = device + arg_dict["ctx_factory"] = ContextFactory(device) + arg_dict["ctx_getter"] = ContextFactory(device) + + arg_values.append(tuple(arg_dict[name] for name in arg_names)) + + if arg_names: + metafunc.parametrize(arg_names, arg_values, ids=str) # {{{ C argument lists