diff --git a/test/test_jax.py b/test/test_jax.py index dc93857fadf4ea1103a705800a8cf93299c61f1d..5214735e17a3f0a67b6510881bf295c355944931 100644 --- a/test/test_jax.py +++ b/test/test_jax.py @@ -23,6 +23,8 @@ THE SOFTWARE. import pytest import numpy as np import pytato as pt + +pytest.importorskip("jax") from jax.config import config config.update("jax_enable_x64", True)