diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index c3fecf835aac9f27e7c87df05c5f9a501a9a0712..7022d3217a3d6b5f928698987220960244e74b72 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -52,8 +52,9 @@ class PytatoPyOpenCLArrayContext(ArrayContext): .. automethod:: __init__ """ - def __init__(self, queue, allocator=None): + def __init__(self, queue, allocator=None, force_device_scalars=True): super().__init__() + assert force_device_scalars == True self._force_device_scalars = True self.queue = queue self.allocator = allocator diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index e9614e1dc6539d750e2d286b5c991e93eeb883d7..9171027b9feb6d5702ede0aab8413198550080ef 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -102,18 +102,12 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( force_device_scalars = False -class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): - force_device_scalars = False +class _PytestPytatoPyOpenCLArrayContextFactory(_PytestPyOpenCLArrayContextFactoryWithClass): - def __call__(self): + @property + def actx_class(self): from arraycontext import PytatoPyOpenCLArrayContext - ctx, queue = self.get_command_queue() - return PytatoPyOpenCLArrayContext(queue) - - def __str__(self): - return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>" - % (self.device.name.strip(), - self.device.platform.name.strip())) + return PytatoPyOpenCLArrayContext _ARRAY_CONTEXT_FACTORY_REGISTRY: \