diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 4fce588551df1dd7d12ec353e958872e75a691cb..b1bbec952770b89635a76ac92effbc1dda56275f 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -189,7 +189,7 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import EagerJAXArrayContext config.update("jax_enable_x64", True) @@ -214,7 +214,7 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import PytatoJAXArrayContext config.update("jax_enable_x64", True)