diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index e680f7eb022f780d785dc9b1ca1290e4f9031207..e5fef3edac9241fd9c283f2f8de2f89488406716 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -41,7 +41,7 @@ from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLi class EagerJAXArrayContext(ArrayContext): """ A :class:`ArrayContext` that uses - :class:`jaxlib.xla_extension.DeviceArrayBase` instances for its base array + :class:`jax.Array` instances for its base array class and performs all array operations eagerly. See :class:`~arraycontext.PytatoJAXArrayContext` for a lazier version.