diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 7022d3217a3d6b5f928698987220960244e74b72..4d987a4aa566ba5a578b659764efc766c5ce3c74 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -54,12 +54,15 @@ class PytatoPyOpenCLArrayContext(ArrayContext): def __init__(self, queue, allocator=None, force_device_scalars=True): super().__init__() - assert force_device_scalars == True + assert force_device_scalars is True self._force_device_scalars = True self.queue = queue self.allocator = allocator self.np = self._get_fake_numpy_namespace() + # unused, but necessary to keep the context alive + self.context = self.queue.context + def _get_fake_numpy_namespace(self): from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace return PytatoFakeNumpyNamespace(self) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 9171027b9feb6d5702ede0aab8413198550080ef..de1e005753c4acb0c046eca60e208ff598b7ea60 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -102,7 +102,8 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( force_device_scalars = False -class _PytestPytatoPyOpenCLArrayContextFactory(_PytestPyOpenCLArrayContextFactoryWithClass): +class _PytestPytatoPyOpenCLArrayContextFactory( + _PytestPyOpenCLArrayContextFactoryWithClass): @property def actx_class(self):