From 8a1f66a21427a838f99c33e13140b43dccfe087b Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Fri, 1 Mar 2024 15:13:31 +0200 Subject: [PATCH] jax: update config import --- arraycontext/pytest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 4fce588..b1bbec9 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -189,7 +189,7 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import EagerJAXArrayContext config.update("jax_enable_x64", True) @@ -214,7 +214,7 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): return False def __call__(self): - from jax.config import config + from jax import config from arraycontext import PytatoJAXArrayContext config.update("jax_enable_x64", True) -- GitLab