Skip to content
Snippets Groups Projects
Unverified Commit 792abc8c authored by Andreas Klöckner's avatar Andreas Klöckner Committed by GitHub
Browse files

Merge pull request #359 from alexfikl/refactor-pytest-fixture

Refactor pytest_generate_tests_for_pyopencl
parents d4bbc613 50d95c54
No related branches found
No related tags found
No related merge requests found
......@@ -172,6 +172,31 @@ atexit.register(clear_first_arg_caches)
# }}}
# {{{ pytest fixtures
class _ContextFactory:
def __init__(self, device):
self.device = device
def __call__(self):
# Get rid of leftovers from past tests.
# CL implementations are surprisingly limited in how many
# simultaneous contexts they allow...
clear_first_arg_caches()
from gc import collect
collect()
import pyopencl as cl
return cl.Context([self.device])
def __str__(self):
# Don't show address, so that parallel test collection works
return ("<context factory for <pyopencl.Device '%s' on '%s'>" %
(self.device.name.strip(),
self.device.platform.name.strip()))
def get_test_platforms_and_devices(plat_dev_string=None):
"""Parse a string of the form 'PYOPENCL_TEST=0:0,1;intel:i5'.
......@@ -229,36 +254,17 @@ def get_test_platforms_and_devices(plat_dev_string=None):
for platform in cl.get_platforms()]
def pytest_generate_tests_for_pyopencl(metafunc):
import pyopencl as cl
class ContextFactory:
def __init__(self, device):
self.device = device
def __call__(self):
# Get rid of leftovers from past tests.
# CL implementations are surprisingly limited in how many
# simultaneous contexts they allow...
clear_first_arg_caches()
from gc import collect
collect()
def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
if extra_arg_names is None:
extra_arg_names = []
return cl.Context([self.device])
def __str__(self):
# Don't show address, so that parallel test collection works
return ("<context factory for <pyopencl.Device '%s' on '%s'>" %
(self.device.name.strip(),
self.device.platform.name.strip()))
test_plat_and_dev = get_test_platforms_and_devices()
supported_arg_names = [
"platform", "device",
"ctx_factory", "ctx_getter",
] + extra_arg_names
arg_names = []
for arg in ("platform", "device", "ctx_factory", "ctx_getter"):
for arg in supported_arg_names:
if arg not in metafunc.fixturenames:
continue
......@@ -270,21 +276,22 @@ def pytest_generate_tests_for_pyopencl(metafunc):
arg_names.append(arg)
arg_values = []
return arg_names
for platform, plat_devs in test_plat_and_dev:
if arg_names == ["platform"]:
arg_values.append((platform,))
continue
def get_pyopencl_fixture_arg_values():
import pyopencl as cl
arg_values = []
for platform, devices in get_test_platforms_and_devices():
arg_dict = {"platform": platform}
for device in plat_devs:
for device in devices:
arg_dict["device"] = device
arg_dict["ctx_factory"] = ContextFactory(device)
arg_dict["ctx_getter"] = ContextFactory(device)
arg_dict["ctx_factory"] = _ContextFactory(device)
arg_dict["ctx_getter"] = _ContextFactory(device)
arg_values.append(tuple(arg_dict[name] for name in arg_names))
arg_values.append(arg_dict)
def idfn(val):
if isinstance(val, cl.Platform):
......@@ -293,8 +300,23 @@ def pytest_generate_tests_for_pyopencl(metafunc):
else:
return str(val)
if arg_names:
metafunc.parametrize(arg_names, arg_values, ids=idfn)
return arg_values, idfn
def pytest_generate_tests_for_pyopencl(metafunc):
arg_names = get_pyopencl_fixture_arg_names(metafunc)
if not arg_names:
return
arg_values, ids = get_pyopencl_fixture_arg_values()
arg_values = [
tuple(arg_dict[name] for name in arg_names)
for arg_dict in arg_values
]
metafunc.parametrize(arg_names, arg_values, ids=ids)
# }}}
# {{{ C argument lists
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment