diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 1b82254eae66117d16fb622faeeaf81427a3965a..0f274a75eec4e02c20b5644e83980ebbbfd0ac64 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,19 +31,46 @@ from arraycontext import ( dataclass_array_container, with_container_arithmetic, serialize_container, deserialize_container, freeze, thaw, - FirstAxisIsElementsTag) + FirstAxisIsElementsTag, + PyOpenCLArrayContext) from arraycontext import ( # noqa: F401 pytest_generate_tests_for_array_contexts, _acf) +from arraycontext.pytest import _PytestPyOpenCLArrayContextFactoryWithClass + import logging logger = logging.getLogger(__name__) +# {{{ array context fixture + +class _PyOpenCLArrayContextForTests(PyOpenCLArrayContext): + """Like :class:`PyOpenCLArrayContext`, but applies no program transformations + whatsoever. Only to be used for testing internal to :mod:`arraycontext`. + """ + + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PyOpenCLArrayContextWithHostScalarsForTestsFactory( + _PytestPyOpenCLArrayContextFactoryWithClass): + actx_class = _PyOpenCLArrayContextForTests + + +class _PyOpenCLArrayContextForTestsFactory( + _PyOpenCLArrayContextWithHostScalarsForTestsFactory): + force_device_scalars = True + + pytest_generate_tests = pytest_generate_tests_for_array_contexts([ - "pyopencl", "pyopencl-deprecated", + _PyOpenCLArrayContextForTestsFactory, + _PyOpenCLArrayContextWithHostScalarsForTestsFactory, ]) +# }}} + # {{{ stand-in DOFArray implementation