diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 2500e61f3b74d761b1e6b6ca18d5c504e0ad1844..2a4b5be902bf2e859553a643daaba8a8682ac693 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -51,10 +51,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): .. automethod:: __init__ """ - def __init__(self, queue, allocator=None, force_device_scalars=True): + def __init__(self, queue, allocator=None): super().__init__() - assert force_device_scalars is True - self._force_device_scalars = True self.queue = queue self.allocator = allocator self.np = self._get_fake_numpy_namespace() diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index de1e005753c4acb0c046eca60e208ff598b7ea60..6f56144c44df062b5973bcb6a6ed9dcbe2b1ccab 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -103,13 +103,23 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( class _PytestPytatoPyOpenCLArrayContextFactory( - _PytestPyOpenCLArrayContextFactoryWithClass): + PytestPyOpenCLArrayContextFactory): @property def actx_class(self): from arraycontext import PytatoPyOpenCLArrayContext return PytatoPyOpenCLArrayContext + def __call__(self): + # The ostensibly pointless assignment to *ctx* keeps the CL context alive + # long enough to create the array context, which will then start + # holding a reference to the context to keep it alive in turn. + # On some implementations (notably Intel CPU), holding a reference + # to a queue does not keep the context alive. + ctx, queue = self.get_command_queue() + return self.actx_class( + queue) + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 5f34113cdd1dbb2a853da923955050f2a6d9bfa1..9f194abd48c30f68c9563aaae0750ab0b876aa94 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -456,10 +456,10 @@ def test_dof_array_reductions_same_as_numpy(actx_factory, op): from numbers import Number - if actx._force_device_scalars: - assert actx_red.shape == () - else: + if isinstance(actx, PyOpenCLArrayContext) and (not actx._force_device_scalars): assert isinstance(actx_red, Number) + else: + assert actx_red.shape == () assert np.allclose(np_red, actx_red)