Skip to content
Snippets Groups Projects
Commit e215f0df authored by Matthias Diener's avatar Matthias Diener
Browse files

run tests for pytatoarraycontext

parent 7da3a871
No related branches found
No related tags found
No related merge requests found
......@@ -52,7 +52,7 @@ from .container.traversal import (
from .impl.pyopencl import PyOpenCLArrayContext
from .pytest import pytest_generate_tests_for_pyopencl_array_context
from .pytest import pytest_generate_tests_for_array_contexts
from .loopy import make_loopy_program
......
......@@ -31,7 +31,35 @@ THE SOFTWARE.
# {{{ pytest integration
def pytest_generate_tests_for_pyopencl_array_context(metafunc):
import pyopencl as cl
from pyopencl.tools import _ContextFactory
class _PyOpenCLArrayContextFactory(_ContextFactory):
def __call__(self):
ctx = super().__call__()
from arraycontext.impl.pyopencl import PyOpenCLArrayContext
return PyOpenCLArrayContext(cl.CommandQueue(ctx))
def __str__(self):
return ("<PyOpenCL array context factory for <pyopencl.Device '%s' on '%s'>" %
(self.device.name.strip(),
self.device.platform.name.strip()))
class _PytatoArrayContextFactory(_ContextFactory):
def __call__(self):
ctx = super().__call__()
from arraycontext.impl.pytato import PytatoArrayContext
return PytatoArrayContext(cl.CommandQueue(ctx))
def __str__(self):
return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>" %
(self.device.name.strip(),
self.device.platform.name.strip()))
def pytest_generate_tests_for_array_contexts(metafunc) -> None:
"""Parametrize tests for pytest to use a
:class:`~arraycontext.PyOpenCLArrayContext`.
......@@ -55,20 +83,6 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc):
for device selection.
"""
import pyopencl as cl
from pyopencl.tools import _ContextFactory
class ArrayContextFactory(_ContextFactory):
def __call__(self):
ctx = super().__call__()
from arraycontext.impl.pyopencl import PyOpenCLArrayContext
return PyOpenCLArrayContext(cl.CommandQueue(ctx))
def __str__(self):
return ("<array context factory for <pyopencl.Device '%s' on '%s'>" %
(self.device.name.strip(),
self.device.platform.name.strip()))
import pyopencl.tools as cl_tools
arg_names = cl_tools.get_pyopencl_fixture_arg_names(
metafunc, extra_arg_names=["actx_factory"])
......@@ -83,14 +97,20 @@ def pytest_generate_tests_for_pyopencl_array_context(metafunc):
"'ctx_factory' / 'ctx_getter' as arguments.")
for arg_dict in arg_values:
arg_dict["actx_factory"] = ArrayContextFactory(arg_dict["device"])
arg_dict["actx_factory"] = _PyOpenCLArrayContextFactory(arg_dict["device"])
arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory(arg_dict["device"])
arg_values = [
arg_values_out = [
tuple(arg_dict[name] for name in arg_names)
for arg_dict in arg_values
]
metafunc.parametrize(arg_names, arg_values, ids=ids)
arg_values_out += [
tuple((arg_dict["actx_factory_pytato"],))
for arg_dict in arg_values
]
metafunc.parametrize(arg_names, arg_values_out, ids=ids)
# }}}
......
......@@ -33,7 +33,7 @@ from arraycontext import (
freeze, thaw,
FirstAxisIsElementsTag)
from arraycontext import ( # noqa: F401
pytest_generate_tests_for_pyopencl_array_context
pytest_generate_tests_for_array_contexts
as pytest_generate_tests,
_acf)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment