diff --git a/pyopencl/tools.py b/pyopencl/tools.py index 02b43ca87e113b8e08590c18f886339b8a9c53c3..edbdd03d906ee33ead45a46191350287a63652ee 100644 --- a/pyopencl/tools.py +++ b/pyopencl/tools.py @@ -106,28 +106,15 @@ atexit.register(clear_first_arg_caches) -def pytest_generate_tests_for_pyopencl(metafunc): - class ContextFactory: - def __init__(self, device): - self.device = device - - def __call__(self): - # Get rid of leftovers from past tests. - # CL implementations are surprisingly limited in how many - # simultaneous contexts they allow... - - clear_first_arg_caches() - - from gc import collect - collect() - - return cl.Context([self.device]) +def get_test_platforms_and_devices(plat_dev_string=None): + """Parse a string of the form 'PYOPENCL_TEST=0:0,1;intel:i5'. - def __str__(self): - return "<context factory for %s>" % self.device + :return: list of tuples (platform, [device, device, ...]) + """ - import os - dev_string = os.environ.get("PYOPENCL_TEST", None) + if plat_dev_string is None: + import os + plat_dev_string = os.environ.get("PYOPENCL_TEST", None) def find_cl_obj(objs, identifier): try: @@ -144,17 +131,15 @@ def pytest_generate_tests_for_pyopencl(metafunc): if not found: raise RuntimeError("object '%s' not found" % identifier) - if dev_string: - # PYOPENCL_TEST=0:0,1;intel:i5 - - test_plat_and_dev = [] # list of tuples (platform, [device, device, ...]) + if plat_dev_string: + result = [] - for entry in dev_string.split(";"): + for entry in plat_dev_string.split(";"): lhsrhs = entry.split(":") if len(lhsrhs) == 1: platform = find_cl_obj(cl.get_platforms(), lhsrhs[0]) - test_plat_and_dev.append((platform, platform.get_devices())) + result.append((platform, platform.get_devices())) elif len(lhsrhs) != 2: raise RuntimeError("invalid syntax of PYOPENCL_TEST") @@ -163,13 +148,41 @@ def pytest_generate_tests_for_pyopencl(metafunc): platform = find_cl_obj(cl.get_platforms(), plat_str) devs = platform.get_devices() - test_plat_and_dev.append( + result.append( (platform, [find_cl_obj(devs, dev_id) for dev_id in dev_strs.split(",")])) + + return result + else: - test_plat_and_dev = [ + return [ (platform, platform.get_devices()) for platform in cl.get_platforms()] + + + +def pytest_generate_tests_for_pyopencl(metafunc): + class ContextFactory: + def __init__(self, device): + self.device = device + + def __call__(self): + # Get rid of leftovers from past tests. + # CL implementations are surprisingly limited in how many + # simultaneous contexts they allow... + + clear_first_arg_caches() + + from gc import collect + collect() + + return cl.Context([self.device]) + + def __str__(self): + return "<context factory for %s>" % self.device + + test_plat_and_dev = get_test_platforms_and_devices() + if ("device" in metafunc.funcargnames or "ctx_factory" in metafunc.funcargnames or "ctx_getter" in metafunc.funcargnames): diff --git a/test/test_wrapper.py b/test/test_wrapper.py index be682dd0cd5c471bf37a708815b032e0cc38fe47..156ac4cd983c7a667dbe1da5b6f847dff61fb0b8 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -153,24 +153,22 @@ class TestCL: lambda info: img.get_image_info(info)) @pytools.test.mark_test.opencl - def test_invalid_kernel_names_cause_failures(self): - for platform in cl.get_platforms(): - for device in platform.get_devices(): - ctx = cl.Context([device]) - prg = cl.Program(ctx, """ - __kernel void sum(__global float *a) - { a[get_global_id(0)] *= 2; } - """).build() - - try: - prg.sam - raise RuntimeError("invalid kernel name did not cause error") - except AttributeError: - pass - except RuntimeError: - raise RuntimeError("weird exception from OpenCL implementation " - "on invalid kernel name--are you using " - "Intel's implementation?") + def test_invalid_kernel_names_cause_failures(self, device): + ctx = cl.Context([device]) + prg = cl.Program(ctx, """ + __kernel void sum(__global float *a) + { a[get_global_id(0)] *= 2; } + """).build() + + try: + prg.sam + raise RuntimeError("invalid kernel name did not cause error") + except AttributeError: + pass + except RuntimeError: + raise RuntimeError("weird exception from OpenCL implementation " + "on invalid kernel name--are you using " + "Intel's implementation?") @pytools.test.mark_test.opencl def test_image_format_constructor(self):