diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 243fd03b2be7fbe4e05c01fdb0c0935313bfad1a..35c313fb8772d095ccab09b95532aa21b7809526 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -51,6 +51,7 @@ from .container.traversal import ( from_numpy, to_numpy) from .impl.pyopencl import PyOpenCLArrayContext +from .impl.pytato import PytatoArrayContext from .pytest import pytest_generate_tests_for_array_contexts @@ -77,7 +78,7 @@ __all__ = ( "thaw", "freeze", "from_numpy", "to_numpy", - "PyOpenCLArrayContext", + "PyOpenCLArrayContext", "PytatoArrayContext", "make_loopy_program", diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 85387a6697209892b3ed344c593d19d8b3cfa33a..5bad85c74658de62d55ada95d441603b76449d95 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -1,6 +1,6 @@ """ .. currentmodule:: arraycontext -.. autofunction:: pytest_generate_tests_for_pyopencl_array_context +.. autofunction:: pytest_generate_tests_for_array_contexts """ @@ -42,8 +42,8 @@ class _PyOpenCLArrayContextFactory(_ContextFactory): return PyOpenCLArrayContext(cl.CommandQueue(ctx)) def __str__(self): - return ("" % - (self.device.name.strip(), + return ("" + % (self.device.name.strip(), self.device.platform.name.strip())) @@ -54,8 +54,8 @@ class _PytatoArrayContextFactory(_ContextFactory): return PytatoArrayContext(cl.CommandQueue(ctx)) def __str__(self): - return ("" % - (self.device.name.strip(), + return ("" + % (self.device.name.strip(), self.device.platform.name.strip())) @@ -97,8 +97,10 @@ def pytest_generate_tests_for_array_contexts(metafunc) -> None: "'ctx_factory' / 'ctx_getter' as arguments.") for arg_dict in arg_values: - arg_dict["actx_factory"] = _PyOpenCLArrayContextFactory(arg_dict["device"]) - arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory(arg_dict["device"]) + arg_dict["actx_factory"] = _PyOpenCLArrayContextFactory(arg_dict[ + "device"]) + arg_dict["actx_factory_pytato"] = _PytatoArrayContextFactory(arg_dict[ + "device"]) arg_values_out = [ tuple(arg_dict[name] for name in arg_names)