From d8bd50d99370ed0f5b348652a1506182b67cd24a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 27 Jun 2021 00:01:12 -0500
Subject: [PATCH] Refactor pytest pyopencl actx factories for easier reuse

---
 arraycontext/pytest.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index c9a13e9..9cc4b4e 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -70,40 +70,47 @@ class PytestPyOpenCLArrayContextFactory:
         raise NotImplementedError
 
 
-class _PyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
+class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
     force_device_scalars = True
 
-    def __call__(self):
+    @property
+    def actx_class(self):
         from arraycontext import PyOpenCLArrayContext
+        return PyOpenCLArrayContext
 
+    def __call__(self):
         # The ostensibly pointless assignment to *ctx* keeps the CL context alive
         # long enough to create the array context, which will then start
         # holding a reference to the context to keep it alive in turn.
         # On some implementations (notably Intel CPU), holding a reference
         # to a queue does not keep the context alive.
         ctx, queue = self.get_command_queue()
-        return PyOpenCLArrayContext(
+        return self.actx_class(
                 queue,
                 force_device_scalars=self.force_device_scalars)
 
     def __str__(self):
-        return ("<PyOpenCLArrayContext for <pyopencl.Device '%s' on '%s'>" %
-                (self.device.name.strip(),
-                 self.device.platform.name.strip()))
+        return ("<%s for <pyopencl.Device '%s' on '%s'>" %
+                (
+                    self.actx_class.__name__,
+                    self.device.name.strip(),
+                    self.device.platform.name.strip()))
 
 
-class _DeprecatedPyOpenCLArrayContextFactory(_PyOpenCLArrayContextFactory):
+class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
+        _PytestPyOpenCLArrayContextFactoryWithClass):
     force_device_scalars = False
 
 
 _ARRAY_CONTEXT_FACTORY_REGISTRY: \
         Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = {
-                "pyopencl": _PyOpenCLArrayContextFactory,
-                "pyopencl-deprecated": _DeprecatedPyOpenCLArrayContextFactory,
+                "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
+                "pyopencl-deprecated":
+                _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars,
                 }
 
 
-def register_array_context_factory(
+def register_pytest_array_context_factory(
         name: str,
         factory: Type[PytestPyOpenCLArrayContextFactory]) -> None:
     if name in _ARRAY_CONTEXT_FACTORY_REGISTRY:
-- 
GitLab