diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 461b2138e58fbb228798b2a157d75388e38838b4..ae1609a5fd9ee1ecde74f826aac6f9c087884b56 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -172,6 +172,31 @@ atexit.register(clear_first_arg_caches) # }}} +# {{{ pytest fixtures + +class _ContextFactory: + def __init__(self, device): + self.device = device + + def __call__(self): + # Get rid of leftovers from past tests. + # CL implementations are surprisingly limited in how many + # simultaneous contexts they allow... + clear_first_arg_caches() + + from gc import collect + collect() + + import pyopencl as cl + return cl.Context([self.device]) + + def __str__(self): + # Don't show address, so that parallel test collection works + return ("<context factory for <pyopencl.Device '%s' on '%s'>" % + (self.device.name.strip(), + self.device.platform.name.strip())) + + def get_test_platforms_and_devices(plat_dev_string=None): """Parse a string of the form 'PYOPENCL_TEST=0:0,1;intel:i5'. @@ -229,36 +254,17 @@ def get_test_platforms_and_devices(plat_dev_string=None): for platform in cl.get_platforms()] -def pytest_generate_tests_for_pyopencl(metafunc): - import pyopencl as cl - - class ContextFactory: - def __init__(self, device): - self.device = device - - def __call__(self): - # Get rid of leftovers from past tests. - # CL implementations are surprisingly limited in how many - # simultaneous contexts they allow... - - clear_first_arg_caches() - - from gc import collect - collect() +def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None): + if extra_arg_names is None: + extra_arg_names = [] - return cl.Context([self.device]) - - def __str__(self): - # Don't show address, so that parallel test collection works - return ("<context factory for <pyopencl.Device '%s' on '%s'>" % - (self.device.name.strip(), - self.device.platform.name.strip())) - - test_plat_and_dev = get_test_platforms_and_devices() + supported_arg_names = [ + "platform", "device", + "ctx_factory", "ctx_getter", + ] + extra_arg_names arg_names = [] - - for arg in ("platform", "device", "ctx_factory", "ctx_getter"): + for arg in supported_arg_names: if arg not in metafunc.fixturenames: continue @@ -270,21 +276,22 @@ def pytest_generate_tests_for_pyopencl(metafunc): arg_names.append(arg) - arg_values = [] + return arg_names - for platform, plat_devs in test_plat_and_dev: - if arg_names == ["platform"]: - arg_values.append((platform,)) - continue +def get_pyopencl_fixture_arg_values(): + import pyopencl as cl + + arg_values = [] + for platform, devices in get_test_platforms_and_devices(): arg_dict = {"platform": platform} - for device in plat_devs: + for device in devices: arg_dict["device"] = device - arg_dict["ctx_factory"] = ContextFactory(device) - arg_dict["ctx_getter"] = ContextFactory(device) + arg_dict["ctx_factory"] = _ContextFactory(device) + arg_dict["ctx_getter"] = _ContextFactory(device) - arg_values.append(tuple(arg_dict[name] for name in arg_names)) + arg_values.append(arg_dict) def idfn(val): if isinstance(val, cl.Platform): @@ -293,8 +300,23 @@ def pytest_generate_tests_for_pyopencl(metafunc): else: return str(val) - if arg_names: - metafunc.parametrize(arg_names, arg_values, ids=idfn) + return arg_values, idfn + + +def pytest_generate_tests_for_pyopencl(metafunc): + arg_names = get_pyopencl_fixture_arg_names(metafunc) + if not arg_names: + return + + arg_values, ids = get_pyopencl_fixture_arg_values() + arg_values = [ + tuple(arg_dict[name] for name in arg_names) + for arg_dict in arg_values + ] + + metafunc.parametrize(arg_names, arg_values, ids=ids) + +# }}} # {{{ C argument lists