From 34c83d315d98b600c7fa82dd48efa537ba8cc589 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Mon, 28 Jun 2021 10:50:29 -0500
Subject: [PATCH] simplify pytato fixture

---
 arraycontext/impl/pytato/__init__.py |  3 ++-
 arraycontext/pytest.py               | 14 ++++----------
 2 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index c3fecf8..7022d32 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -52,8 +52,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     .. automethod:: __init__
     """
 
-    def __init__(self, queue, allocator=None):
+    def __init__(self, queue, allocator=None, force_device_scalars=True):
         super().__init__()
+        assert force_device_scalars == True
         self._force_device_scalars = True
         self.queue = queue
         self.allocator = allocator
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index e9614e1..9171027 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -102,18 +102,12 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
     force_device_scalars = False
 
 
-class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
-    force_device_scalars = False
+class _PytestPytatoPyOpenCLArrayContextFactory(_PytestPyOpenCLArrayContextFactoryWithClass):
 
-    def __call__(self):
+    @property
+    def actx_class(self):
         from arraycontext import PytatoPyOpenCLArrayContext
-        ctx, queue = self.get_command_queue()
-        return PytatoPyOpenCLArrayContext(queue)
-
-    def __str__(self):
-        return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>"
-                % (self.device.name.strip(),
-                 self.device.platform.name.strip()))
+        return PytatoPyOpenCLArrayContext
 
 
 _ARRAY_CONTEXT_FACTORY_REGISTRY: \
-- 
GitLab