diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 34df04b60aaf57fc98ce8236b76cd0842314d130..243fd03b2be7fbe4e05c01fdb0c0935313bfad1a 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -52,7 +52,7 @@ from .container.traversal import ( from .impl.pyopencl import PyOpenCLArrayContext -from .pytest import pytest_generate_tests_for_pyopencl_array_context +from .pytest import pytest_generate_tests_for_array_contexts from .loopy import make_loopy_program diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 3327aa4cedf1aaa989eca15338771f87c550d3f1..12aef9d57a275397649fe9470b4253a12a71f9f3 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -31,7 +31,35 @@ THE SOFTWARE. # {{{ pytest integration -def pytest_generate_tests_for_pyopencl_array_context(metafunc): +import pyopencl as cl +from pyopencl.tools import _ContextFactory + + +class _PyOpenCLArrayContextFactory(_ContextFactory): + def __call__(self): + ctx = super().__call__() + from arraycontext.impl.pyopencl import PyOpenCLArrayContext + return PyOpenCLArrayContext(cl.CommandQueue(ctx)) + + def __str__(self): + return ("<PyOpenCL array context factory for <pyopencl.Device '%s' on '%s'>" % + (self.device.name.strip(), + self.device.platform.name.strip())) + + +class _PytatoArrayContextFactory(_ContextFactory): + def __call__(self): + ctx = super().__call__() + from arraycontext.impl.pytato import PytatoArrayContext + return PytatoArrayContext(cl.CommandQueue(ctx)) + + def __str__(self): + return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>" % + (self.device.name.strip(), + self.device.platform.name.strip())) + + +def pytest_generate_tests_for_array_contexts(metafunc) -> None: """Parametrize tests for pytest to use a :class:`~arraycontext.PyOpenCLArrayContext`. @@ -55,20 +83,6 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc): for device selection. """ - import pyopencl as cl - from pyopencl.tools import _ContextFactory - - class ArrayContextFactory(_ContextFactory): - def __call__(self): - ctx = super().__call__() - from arraycontext.impl.pyopencl import PyOpenCLArrayContext - return PyOpenCLArrayContext(cl.CommandQueue(ctx)) - - def __str__(self): - return ("<array context factory for <pyopencl.Device '%s' on '%s'>" % - (self.device.name.strip(), - self.device.platform.name.strip())) - import pyopencl.tools as cl_tools arg_names = cl_tools.get_pyopencl_fixture_arg_names( metafunc, extra_arg_names=["actx_factory"]) @@ -83,14 +97,20 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc): "'ctx_factory' / 'ctx_getter' as arguments.") for arg_dict in arg_values: - arg_dict["actx_factory"] = ArrayContextFactory(arg_dict["device"]) + arg_dict["actx_factory"] = _PyOpenCLArrayContextFactory(arg_dict["device"]) + arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory(arg_dict["device"]) - arg_values = [ + arg_values_out = [ tuple(arg_dict[name] for name in arg_names) for arg_dict in arg_values ] - metafunc.parametrize(arg_names, arg_values, ids=ids) + arg_values_out += [ + tuple((arg_dict["actx_factory_pytato"],)) + for arg_dict in arg_values + ] + + metafunc.parametrize(arg_names, arg_values_out, ids=ids) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index f1f78a16b1b7e35bb2570f86cced027ab7b5d286..abd9843fce95534a4ddcf90fe9ca7d20bc87fce7 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -33,7 +33,7 @@ from arraycontext import ( freeze, thaw, FirstAxisIsElementsTag) from arraycontext import ( # noqa: F401 - pytest_generate_tests_for_pyopencl_array_context + pytest_generate_tests_for_array_contexts as pytest_generate_tests, _acf)