diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 662b86c05a19e3d40f9448a9ffb82b01dc1bd3ee..6c9192ba7ebb979b726ccf53454922c60784da4f 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -23,9 +23,10 @@ THE SOFTWARE. """ from functools import partial, reduce -import jax.numpy as jnp import numpy as np +import jax.numpy as jnp + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 6e04bdcdeae931429ffbf4a64d5091b9b1c5c196..c030de7188b484d69c00e8fcf29bf9fc8668f694 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -720,7 +720,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): unstable. """ import jax.numpy as jnp - import pytato as pt super().__init__(compile_trace_callback=compile_trace_callback) self.array_types = (pt.Array, jnp.ndarray) @@ -766,7 +765,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def from_numpy(self, array): import jax - import pytato as pt def _from_numpy(ary): @@ -791,7 +789,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): return array import jax.numpy as jnp - import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 3ea7d065d644dbe6b02ba3cfb3605558f6647973..d3d719e515c573ae618d5005c86b7feca6c0c160 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -206,7 +206,6 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): def is_available(cls) -> bool: try: import jax # noqa: F401 - import pytato # noqa: F401 return True except ImportError: