diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index db5dcb054435d81e9e6456490070fdea6cec29b0..77e51feb8e5e0b8fb02c042a05d2eed1be702aa7 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -100,8 +100,9 @@ class _PytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): force_device_scalars = False def __call__(self): - from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext - return PytatoPyOpenCLArrayContext(self.get_command_queue()) + from arraycontext import PytatoPyOpenCLArrayContext + ctx, queue = self.get_command_queue() + return PytatoPyOpenCLArrayContext(queue) def __str__(self): return ("<Pytato array context factory for <pyopencl.Device '%s' on '%s'>"