From 0de1de1c61d945d411efb8826c9909ca5998ffac Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Mon, 28 Jun 2021 12:51:04 -0500 Subject: [PATCH] PytatoPyOpenCLArrayContext: shouldn't refer to _force_device_scalars --- arraycontext/impl/pytato/__init__.py | 4 +--- arraycontext/pytest.py | 12 +++++++++++- test/test_arraycontext.py | 6 +++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 2500e61..2a4b5be 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -51,10 +51,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): .. automethod:: __init__ """ - def __init__(self, queue, allocator=None, force_device_scalars=True): + def __init__(self, queue, allocator=None): super().__init__() - assert force_device_scalars is True - self._force_device_scalars = True self.queue = queue self.allocator = allocator self.np = self._get_fake_numpy_namespace() diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index de1e005..6f56144 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -103,13 +103,23 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( class _PytestPytatoPyOpenCLArrayContextFactory( - _PytestPyOpenCLArrayContextFactoryWithClass): + PytestPyOpenCLArrayContextFactory): @property def actx_class(self): from arraycontext import PytatoPyOpenCLArrayContext return PytatoPyOpenCLArrayContext + def __call__(self): + # The ostensibly pointless assignment to *ctx* keeps the CL context alive + # long enough to create the array context, which will then start + # holding a reference to the context to keep it alive in turn. + # On some implementations (notably Intel CPU), holding a reference + # to a queue does not keep the context alive. + ctx, queue = self.get_command_queue() + return self.actx_class( + queue) + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestPyOpenCLArrayContextFactory]] = { diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 5f34113..9f194ab 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -456,10 +456,10 @@ def test_dof_array_reductions_same_as_numpy(actx_factory, op): from numbers import Number - if actx._force_device_scalars: - assert actx_red.shape == () - else: + if isinstance(actx, PyOpenCLArrayContext) and (not actx._force_device_scalars): assert isinstance(actx_red, Number) + else: + assert actx_red.shape == () assert np.allclose(np_red, actx_red) -- GitLab