From d8bd50d99370ed0f5b348652a1506182b67cd24a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 27 Jun 2021 00:01:12 -0500 Subject: [PATCH] Refactor pytest pyopencl actx factories for easier reuse --- arraycontext/pytest.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index c9a13e9..9cc4b4e 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -70,40 +70,47 @@ class PytestPyOpenCLArrayContextFactory: raise NotImplementedError -class _PyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): +class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory): force_device_scalars = True - def __call__(self): + @property + def actx_class(self): from arraycontext import PyOpenCLArrayContext + return PyOpenCLArrayContext + def __call__(self): # The ostensibly pointless assignment to *ctx* keeps the CL context alive # long enough to create the array context, which will then start # holding a reference to the context to keep it alive in turn. # On some implementations (notably Intel CPU), holding a reference # to a queue does not keep the context alive. ctx, queue = self.get_command_queue() - return PyOpenCLArrayContext( + return self.actx_class( queue, force_device_scalars=self.force_device_scalars) def __str__(self): - return ("<PyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>" % - (self.device.name.strip(), - self.device.platform.name.strip())) + return ("<%s for <pyopencl.Device '%s' on '%s'>" % + ( + self.actx_class.__name__, + self.device.name.strip(), + self.device.platform.name.strip())) -class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory): +class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( + _PytestPyOpenCLArrayContextFactoryWithClass): force_device_scalars = False _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { - "pyopencl": _PyOpenCLArrayContextFactory, - "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory, + "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, + "pyopencl-deprecated": + _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, } -def register_array_context_factory( +def register_pytest_array_context_factory( name: str, factory: Type[PytestPyOpenCLArrayContextFactory]) -> None: if name in _ARRAY_CONTEXT_FACTORY_REGISTRY: -- GitLab