From 0de1de1c61d945d411efb8826c9909ca5998ffac Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Mon, 28 Jun 2021 12:51:04 -0500
Subject: [PATCH] PytatoPyOpenCLArrayContext: shouldn't refer to
 _force_device_scalars

---
 arraycontext/impl/pytato/__init__.py |  4 +---
 arraycontext/pytest.py               | 12 +++++++++++-
 test/test_arraycontext.py            |  6 +++---
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 2500e61..2a4b5be 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -51,10 +51,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
     .. automethod:: __init__
     """
 
-    def __init__(self, queue, allocator=None, force_device_scalars=True):
+    def __init__(self, queue, allocator=None):
         super().__init__()
-        assert force_device_scalars is True
-        self._force_device_scalars = True
         self.queue = queue
         self.allocator = allocator
         self.np = self._get_fake_numpy_namespace()
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index de1e005..6f56144 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -103,13 +103,23 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
 
 
 class _PytestPytatoPyOpenCLArrayContextFactory(
-        _PytestPyOpenCLArrayContextFactoryWithClass):
+        PytestPyOpenCLArrayContextFactory):
 
     @property
     def actx_class(self):
         from arraycontext import PytatoPyOpenCLArrayContext
         return PytatoPyOpenCLArrayContext
 
+    def __call__(self):
+        # The ostensibly pointless assignment to *ctx* keeps the CL context alive
+        # long enough to create the array context, which will then start
+        # holding a reference to the context to keep it alive in turn.
+        # On some implementations (notably Intel CPU), holding a reference
+        # to a queue does not keep the context alive.
+        ctx, queue = self.get_command_queue()
+        return self.actx_class(
+                queue)
+
 
 _ARRAY_CONTEXT_FACTORY_REGISTRY: \
         Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = {
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 5f34113..9f194ab 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -456,10 +456,10 @@ def test_dof_array_reductions_same_as_numpy(actx_factory, op):
 
     from numbers import Number
 
-    if actx._force_device_scalars:
-        assert actx_red.shape == ()
-    else:
+    if isinstance(actx, PyOpenCLArrayContext) and (not actx._force_device_scalars):
         assert isinstance(actx_red, Number)
+    else:
+        assert actx_red.shape == ()
 
     assert np.allclose(np_red, actx_red)
 
-- 
GitLab