diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index c9a13e96c90e0e77e46db9e1d492f64f22fa6da9..9cc4b4e5f779594a1906d4a20a467edcc2d312c7 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -70,40 +70,47 @@ class PytestPyOpenCLArrayContextFactory: raise NotImplementedError -class _PyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): +class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory): force_device_scalars = True - def __call__(self): + @property + def actx_class(self): from arraycontext import PyOpenCLArrayContext + return PyOpenCLArrayContext + def __call__(self): # The ostensibly pointless assignment to *ctx* keeps the CL context alive # long enough to create the array context, which will then start # holding a reference to the context to keep it alive in turn. # On some implementations (notably Intel CPU), holding a reference # to a queue does not keep the context alive. ctx, queue = self.get_command_queue() - return PyOpenCLArrayContext( + return self.actx_class( queue, force_device_scalars=self.force_device_scalars) def __str__(self): - return ("<PyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>" % - (self.device.name.strip(), - self.device.platform.name.strip())) + return ("<%s for <pyopencl.Device '%s' on '%s'>" % + ( + self.actx_class.__name__, + self.device.name.strip(), + self.device.platform.name.strip())) -class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory): +class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( + _PytestPyOpenCLArrayContextFactoryWithClass): force_device_scalars = False _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { - "pyopencl": _PyOpenCLArrayContextFactory, - "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory, + "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, + "pyopencl-deprecated": + _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, } -def register_array_context_factory( +def register_pytest_array_context_factory( name: str, factory: Type[PytestPyOpenCLArrayContextFactory]) -> None: if name in _ARRAY_CONTEXT_FACTORY_REGISTRY: