From c658c132d76707a7fcb1fcc4cbfee88467052ffa Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Tue, 9 Jul 2024 20:26:24 +0300 Subject: [PATCH] isort: add jax as a first-party --- arraycontext/impl/jax/fake_numpy.py | 3 ++- arraycontext/impl/pytato/__init__.py | 3 --- arraycontext/pytest.py | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 662b86c..6c9192b 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -23,9 +23,10 @@ THE SOFTWARE. """ from functools import partial, reduce -import jax.numpy as jnp import numpy as np +import jax.numpy as jnp + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 6e04bdc..c030de7 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -720,7 +720,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): unstable. """ import jax.numpy as jnp - import pytato as pt super().__init__(compile_trace_callback=compile_trace_callback) self.array_types = (pt.Array, jnp.ndarray) @@ -766,7 +765,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def from_numpy(self, array): import jax - import pytato as pt def _from_numpy(ary): @@ -791,7 +789,6 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): return array import jax.numpy as jnp - import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 3ea7d06..d3d719e 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -206,7 +206,6 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory): def is_available(cls) -> bool: try: import jax # noqa: F401 - import pytato # noqa: F401 return True except ImportError: -- GitLab