diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index cfba57f3790ab12a912c92cdc9cb9e793e9a04b6..5af5d1d1c88f26ab3f019cfbf768ece596f96818 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -59,7 +59,7 @@ class _PytatoArrayContextFactory(_ContextFactory): self.device.platform.name.strip())) -def pytest_generate_tests_for_array_contexts(metafunc) -> None: +def pytest_generate_tests_for_array_contexts(metafunc, use_pytato=True) -> None: """Parametrize tests for pytest to use a :class:`~arraycontext.PyOpenCLArrayContext`. @@ -99,15 +99,17 @@ def pytest_generate_tests_for_array_contexts(metafunc) -> None: for arg_dict in arg_values: arg_dict["actx_factory"] = _PyOpenCLArrayContextFactory(arg_dict[ "device"]) - arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory(arg_dict[ - "device"]) + + if use_pytato: + arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory( + arg_dict["device"]) arg_values_out = [ tuple(arg_dict[name] for name in arg_names) for arg_dict in arg_values ] - if "actx_factory" in arg_names: + if "actx_factory_pytato" in arg_dict: arg_values_out += [ tuple((arg_dict["actx_factory_pytato"],)) for arg_dict in arg_values @@ -121,7 +123,7 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None: warn("'pytato.pytest_generate_tests_for_pyopencl_array_context' " "is deprecated, use 'pytato.pytest_generate_tests_for_array_contexts' " "instead.", DeprecationWarning) - pytest_generate_tests_for_array_contexts(metafunc) + pytest_generate_tests_for_array_contexts(metafunc, use_pytato=False) # }}}