diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 24557ed5df9581c5c93d57c687c82bd2b31b409d..0d534c95bb5dbd1e0f1af86d34a16870bffd74f0 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -31,6 +31,8 @@ THE SOFTWARE. # {{{ pytest integration +from arraycontext.impl.pyopencl import PyOpenCLArrayContext +from arraycontext.impl.pytato import PytatoArrayContext import pyopencl as cl from pyopencl.tools import _ContextFactory @@ -59,9 +61,13 @@ class _PytatoArrayContextFactory(_ContextFactory): self.device.platform.name.strip())) -def pytest_generate_tests_for_array_contexts(metafunc, use_pytato=True) -> None: +types_to_factories = {PyOpenCLArrayContext: _PyOpenCLArrayContextFactory, + PytatoArrayContext: _PytatoArrayContextFactory} + + +def pytest_generate_tests_for_array_contexts(metafunc, actx_list=None) -> None: """Parametrize tests for pytest to use a - :class:`~arraycontext.PyOpenCLArrayContext`. + :class:`~arraycontext.ArrayContext`. Performs device enumeration analogously to :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl`. @@ -70,7 +76,7 @@ def pytest_generate_tests_for_array_contexts(metafunc, use_pytato=True) -> None: .. code-block:: python - from arraycontext import pytest_generate_tests_for_pyopencl + from arraycontext import pytest_generate_tests_for_array_contexts as pytest_generate_tests in your pytest test scripts allows you to use the argument ``actx_factory``, @@ -83,6 +89,11 @@ def pytest_generate_tests_for_array_contexts(metafunc, use_pytato=True) -> None: for device selection. """ + if actx_list is None: + actx_list = [PyOpenCLArrayContext, PytatoArrayContext] + + actx_factories = [types_to_factories[a] for a in actx_list] + import pyopencl.tools as cl_tools arg_names = cl_tools.get_pyopencl_fixture_arg_names( metafunc, extra_arg_names=["actx_factory"]) @@ -97,21 +108,24 @@ def pytest_generate_tests_for_array_contexts(metafunc, use_pytato=True) -> None: "'ctx_factory' / 'ctx_getter' as arguments.") for arg_dict in arg_values: - arg_dict["actx_factory"] = _PyOpenCLArrayContextFactory(arg_dict[ - "device"]) + dev = arg_dict["device"] + extra_factories = [] - if use_pytato: - arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory( - arg_dict["device"]) + for factory in actx_factories: + if "actx_factory" in arg_dict: + arg_dict["actx_factory_"+str(factory)] = factory(dev) + extra_factories += ("actx_factory_"+str(factory),) + else: + arg_dict["actx_factory"] = factory(dev) arg_values_out = [ tuple(arg_dict[name] for name in arg_names) for arg_dict in arg_values ] - if "actx_factory_pytato" in arg_dict: + for extra_factory in extra_factories: arg_values_out += [ - tuple((arg_dict["actx_factory_pytato"],)) + tuple((arg_dict[extra_factory],)) for arg_dict in arg_values ] @@ -122,9 +136,10 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None: from warnings import warn warn("'pytato.pytest_generate_tests_for_pyopencl_array_context' " "is deprecated, use 'pytato.pytest_generate_tests_for_array_contexts' " - "instead. PytatoArrayContext tests will be disabled.", + "instead. Only tests with PyOpenCLArrayContext will be run.", DeprecationWarning) - pytest_generate_tests_for_array_contexts(metafunc, use_pytato=False) + pytest_generate_tests_for_array_contexts(metafunc, + actx_list=[PyOpenCLArrayContext]) # }}}