From 600c659d05ccc4d65ac8887858ba80a6d5305c44 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Wed, 11 Jan 2023 17:23:00 +0200
Subject: [PATCH] enable and apply isort

---
 arraycontext/__init__.py                      | 73 +++++++------------
 arraycontext/container/__init__.py            |  8 +-
 arraycontext/container/arithmetic.py          |  3 +-
 arraycontext/container/dataclass.py           |  4 +-
 arraycontext/container/traversal.py           | 19 ++---
 arraycontext/context.py                       | 10 ++-
 arraycontext/fake_numpy.py                    |  3 +-
 arraycontext/impl/jax/__init__.py             |  7 +-
 arraycontext/impl/jax/fake_numpy.py           | 14 ++--
 arraycontext/impl/pyopencl/__init__.py        | 11 +--
 arraycontext/impl/pyopencl/fake_numpy.py      | 18 ++---
 .../impl/pyopencl/taggable_cl_array.py        |  4 +-
 arraycontext/impl/pytato/__init__.py          | 55 ++++++++------
 arraycontext/impl/pytato/compile.py           | 38 +++++-----
 arraycontext/impl/pytato/fake_numpy.py        | 17 ++---
 arraycontext/impl/pytato/utils.py             | 14 ++--
 arraycontext/loopy.py                         |  8 +-
 arraycontext/metadata.py                      |  7 +-
 arraycontext/pytest.py                        | 20 +++--
 setup.cfg                                     |  8 ++
 test/test_arraycontext.py                     | 43 +++++------
 test/test_pytato_arraycontext.py              | 14 ++--
 test/test_utils.py                            | 13 ++--
 23 files changed, 205 insertions(+), 206 deletions(-)

diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index 06e0b96..00e542d 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -29,61 +29,38 @@ THE SOFTWARE.
 """
 
 import sys
-from .context import (
-        ArrayContext,
-
-        Scalar, ScalarLike,
-        Array, ArrayT,
-        ArrayOrContainer, ArrayOrContainerT,
-        ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT,
-
-        tag_axes)
-
-from .transform_metadata import (CommonSubexpressionTag,
-        ElementwiseMapKernelTag)
-
-# deprecated, remove in 2022.
-from .metadata import _FirstAxisIsElementsTag
 
 from .container import (
-        ArrayContainer, ArrayContainerT,
-        NotAnArrayContainerError,
-        is_array_container, is_array_container_type,
-        get_container_context_opt,
-        get_container_context_recursively, get_container_context_recursively_opt,
-        serialize_container, deserialize_container,
-        register_multivector_as_array_container)
+    ArrayContainer, ArrayContainerT, NotAnArrayContainerError, deserialize_container,
+    get_container_context_opt, get_container_context_recursively,
+    get_container_context_recursively_opt, is_array_container,
+    is_array_container_type, register_multivector_as_array_container,
+    serialize_container)
 from .container.arithmetic import with_container_arithmetic
 from .container.dataclass import dataclass_array_container
-
 from .container.traversal import (
-        map_array_container,
-        multimap_array_container,
-        rec_map_array_container,
-        rec_multimap_array_container,
-        mapped_over_array_containers,
-        multimapped_over_array_containers,
-        map_reduce_array_container,
-        multimap_reduce_array_container,
-        rec_map_reduce_array_container,
-        rec_multimap_reduce_array_container,
-        thaw, freeze,
-        flatten, unflatten, flat_size_and_dtype,
-        from_numpy, to_numpy,
-        outer, with_array_context)
-
-from .impl.pyopencl import PyOpenCLArrayContext
-from .impl.pytato import (PytatoPyOpenCLArrayContext,
-                          PytatoJAXArrayContext)
+    flat_size_and_dtype, flatten, freeze, from_numpy, map_array_container,
+    map_reduce_array_container, mapped_over_array_containers,
+    multimap_array_container, multimap_reduce_array_container,
+    multimapped_over_array_containers, outer, rec_map_array_container,
+    rec_map_reduce_array_container, rec_multimap_array_container,
+    rec_multimap_reduce_array_container, thaw, to_numpy, unflatten,
+    with_array_context)
+from .context import (
+    Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
+    ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike,
+    tag_axes)
 from .impl.jax import EagerJAXArrayContext
-
-from .pytest import (
-        PytestArrayContextFactory,
-        PytestPyOpenCLArrayContextFactory,
-        pytest_generate_tests_for_array_contexts,
-        pytest_generate_tests_for_pyopencl_array_context)
-
+from .impl.pyopencl import PyOpenCLArrayContext
+from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
 from .loopy import make_loopy_program
+# deprecated, remove in 2022.
+from .metadata import _FirstAxisIsElementsTag
+from .pytest import (
+    PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory,
+    pytest_generate_tests_for_array_contexts,
+    pytest_generate_tests_for_pyopencl_array_context)
+from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
 
 
 __all__ = (
diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 71bccee..fcb130f 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -69,17 +69,19 @@ THE SOFTWARE.
 """
 
 from functools import singledispatch
-from arraycontext.context import ArrayContext
-from typing import Any, Iterable, Tuple, Optional, TypeVar, Protocol, TYPE_CHECKING
-import numpy as np
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar
 
 # For use in singledispatch type annotations, because sphinx can't figure out
 # what 'np' is.
 import numpy
+import numpy as np
+
+from arraycontext.context import ArrayContext
 
 
 if TYPE_CHECKING:
     from pymbolic.geometric_algebra import MultiVector
+
     from arraycontext import ArrayOrContainer
 
 
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 5e2ade2..148d34b 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -7,6 +7,7 @@
 
 import enum
 
+
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
 """
@@ -31,8 +32,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union
 from warnings import warn
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union, Type
 
 import numpy as np
 
diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index ee5c2f0..e9ab38d 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -30,9 +30,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from dataclasses import Field, fields, is_dataclass
 from typing import Tuple, Union, get_args, get_origin
 
-from dataclasses import Field, is_dataclass, fields
 from arraycontext.container import is_array_container_type
 
 
@@ -95,7 +95,7 @@ def dataclass_array_container(cls: type) -> type:
             # NOTE:
             # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
             # * `_SpecialForm` catches `Any`, `Literal`, etc.
-            from typing import (                    # type: ignore[attr-defined]
+            from typing import (  # type: ignore[attr-defined]
                 _BaseGenericAlias, _SpecialForm)
             if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
                 # NOTE: anything except a Union is not allowed
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index c41b464..048f6a1 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -41,6 +41,7 @@ Algebraic operations
 
 from __future__ import annotations
 
+
 __copyright__ = """
 Copyright (C) 2020-1 University of Illinois Board of Trustees
 """
@@ -65,22 +66,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any, Callable, Iterable, List, Optional, Union, Tuple, cast
-from functools import update_wrapper, partial, singledispatch
+from functools import partial, singledispatch, update_wrapper
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast
 from warnings import warn
 
 import numpy as np
 
-from arraycontext.context import (
-    ArrayT, ArrayOrContainer, ArrayOrContainerT,
-    ArrayOrContainerOrScalar, ScalarLike,
-    ArrayContext, Array
-)
 from arraycontext.container import (
-        NotAnArrayContainerError,
-        ArrayContainer,
-        serialize_container, deserialize_container,
-        get_container_context_recursively_opt)
+    ArrayContainer, NotAnArrayContainerError, deserialize_container,
+    get_container_context_recursively_opt, serialize_container)
+from arraycontext.context import (
+    Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
+    ArrayOrContainerT, ArrayT, ScalarLike)
 
 
 # {{{ array container traversal helpers
diff --git a/arraycontext/context.py b/arraycontext/context.py
index 2378550..f844106 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -160,15 +160,18 @@ THE SOFTWARE.
 
 from abc import ABC, abstractmethod
 from typing import (
-        Any, Callable, Dict, Optional, Tuple, Union, Mapping, Protocol, TypeVar,
-        TYPE_CHECKING)
+    TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, Protocol, Tuple, TypeVar,
+    Union)
 
 import numpy as np
+
 from pytools import memoize_method
 from pytools.tag import ToTagSetConvertible
 
+
 if TYPE_CHECKING:
     import loopy
+
     from arraycontext.container import ArrayContainer
 
 
@@ -426,8 +429,9 @@ class ArrayContext(ABC):
                         spec: str, arg_names: Tuple[str, ...],
                         tagged: ToTagSetConvertible) -> "loopy.TranslationUnit":
         import loopy as lp
-        from .loopy import _DEFAULT_LOOPY_OPTIONS
         from loopy.version import MOST_RECENT_LANGUAGE_VERSION
+
+        from .loopy import _DEFAULT_LOOPY_OPTIONS
         return lp.make_einsum(
             spec,
             arg_names,
diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py
index c3e37f8..cf416fa 100644
--- a/arraycontext/fake_numpy.py
+++ b/arraycontext/fake_numpy.py
@@ -24,6 +24,7 @@ THE SOFTWARE.
 
 
 import numpy as np
+
 from arraycontext.container import NotAnArrayContainerError, serialize_container
 from arraycontext.container.traversal import rec_map_array_container
 
@@ -105,8 +106,8 @@ class BaseFakeNumpyNamespace:
 # {{{ BaseFakeNumpyLinalgNamespace
 
 def _reduce_norm(actx, arys, ord):
-    from numbers import Number
     from functools import reduce
+    from numbers import Number
 
     if ord is None:
         ord = 2
diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
index 4aa30c2..e680f7e 100644
--- a/arraycontext/impl/jax/__init__.py
+++ b/arraycontext/impl/jax/__init__.py
@@ -32,9 +32,10 @@ from typing import Callable, Optional, Tuple
 import numpy as np
 
 from pytools.tag import ToTagSetConvertible
-from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
-from arraycontext.container.traversal import (with_array_context,
-                                              rec_map_array_container)
+
+from arraycontext.container.traversal import (
+    rec_map_array_container, with_array_context)
+from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
 
 
 class EagerJAXArrayContext(ArrayContext):
diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py
index daaf880..0955820 100644
--- a/arraycontext/impl/jax/fake_numpy.py
+++ b/arraycontext/impl/jax/fake_numpy.py
@@ -23,17 +23,15 @@ THE SOFTWARE.
 """
 from functools import partial, reduce
 
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
 
-from arraycontext.fake_numpy import (
-        BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
-        )
-from arraycontext.container.traversal import (
-        rec_multimap_array_container, rec_map_array_container,
-        rec_map_reduce_array_container,
-        )
 from arraycontext.container import NotAnArrayContainerError, serialize_container
+from arraycontext.container.traversal import (
+    rec_map_array_container, rec_map_reduce_array_container,
+    rec_multimap_array_container)
+from arraycontext.fake_numpy import (
+    BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace)
 
 
 class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index aced309..4064689 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -28,21 +28,21 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 from warnings import warn
-from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
 
 import numpy as np
 
 from pytools.tag import ToTagSetConvertible
 
-from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
-from arraycontext.container.traversal import (rec_map_array_container,
-                                              with_array_context)
+from arraycontext.container.traversal import (
+    rec_map_array_container, with_array_context)
+from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
 
 
 if TYPE_CHECKING:
-    import pyopencl
     import loopy as lp
+    import pyopencl
 
 
 # {{{ PyOpenCLArrayContext
@@ -287,6 +287,7 @@ class PyOpenCLArrayContext(ArrayContext):
                 wait_event_queue.pop(0).wait()
 
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
+
         # FIXME: Inherit loopy tags for these arrays
         return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}
 
diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py
index a0180e7..71af653 100644
--- a/arraycontext/impl/pyopencl/fake_numpy.py
+++ b/arraycontext/impl/pyopencl/fake_numpy.py
@@ -26,24 +26,18 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from functools import partial, reduce
 import operator
+from functools import partial, reduce
 
 import numpy as np
 
-from arraycontext.fake_numpy import (
-        BaseFakeNumpyLinalgNamespace
-        )
-from arraycontext.loopy import (
-        LoopyBasedFakeNumpyNamespace
-        )
 from arraycontext.container import NotAnArrayContainerError, serialize_container
 from arraycontext.container.traversal import (
-        rec_map_array_container,
-        rec_multimap_array_container,
-        rec_map_reduce_array_container,
-        rec_multimap_reduce_array_container,
-        )
+    rec_map_array_container, rec_map_reduce_array_container,
+    rec_multimap_array_container, rec_multimap_reduce_array_container)
+from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
+from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
+
 
 try:
     import pyopencl as cl  # noqa: F401
diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py
index 49ae08b..0361200 100644
--- a/arraycontext/impl/pyopencl/taggable_cl_array.py
+++ b/arraycontext/impl/pyopencl/taggable_cl_array.py
@@ -9,10 +9,10 @@ from dataclasses import dataclass
 from typing import Any, Dict, FrozenSet, Optional, Tuple
 
 import numpy as np
-import pyopencl.array as cla
 
+import pyopencl.array as cla
 from pytools import memoize
-from pytools.tag import Taggable, Tag, ToTagSetConvertible
+from pytools.tag import Tag, Taggable, ToTagSetConvertible
 
 
 # {{{ utils
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index c3e4462..8fe3055 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -44,27 +44,30 @@ THE SOFTWARE.
 
 import abc
 import sys
-from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional,
-                    TYPE_CHECKING)
+from typing import (
+    TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional, Tuple, Type, Union)
 
 import numpy as np
-from pytools.tag import ToTagSetConvertible, normalize_tags, Tag
 
-from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
-from arraycontext.container.traversal import (rec_map_array_container,
-                                              with_array_context)
-from arraycontext.metadata import NameHint
 from pytools import memoize_method
+from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
+
+from arraycontext.container.traversal import (
+    rec_map_array_container, with_array_context)
+from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
+from arraycontext.metadata import NameHint
+
 
 if TYPE_CHECKING:
-    import pytato
     import pyopencl as cl
+    import pytato
 
 if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
     import pyopencl as cl  # noqa: F811
 
-
 import logging
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -126,8 +129,8 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
         """
         super().__init__()
 
-        import pytato as pt
         import loopy as lp
+        import pytato as pt
         self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
         self._dag_transform_cache: Dict[
                 pt.DictOfNamedArrays,
@@ -292,8 +295,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             except ImportError:
                 self.using_svm = False
 
-        import pytato as pt
         import pyopencl.array as cla
+        import pytato as pt
         super().__init__(compile_trace_callback=compile_trace_callback)
         self.queue = queue
 
@@ -316,6 +319,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             default_scalar: Optional[ScalarLike] = None,
             strict: bool = False) -> ArrayOrContainer:
         import pytato as pt
+
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         if allowed_types is None:
@@ -357,6 +361,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
     def from_numpy(self, array):
         import pytato as pt
+
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         def _from_numpy(ary):
@@ -415,8 +420,8 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
             logger.info(f"limiting argument buffer size for {dev} to {limit} bytes")
 
-            from arraycontext.impl.pytato.utils import \
-                    ArgSizeLimitingPytatoLoopyPyOpenCLTarget
+            from arraycontext.impl.pytato.utils import (
+                ArgSizeLimitingPytatoLoopyPyOpenCLTarget)
             return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
         else:
             return super().get_target()
@@ -425,15 +430,15 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         if np.isscalar(array):
             return array
 
-        import pytato as pt
         import pyopencl.array as cla
+        import pytato as pt
 
         from arraycontext.container.traversal import rec_keyed_map_array_container
-        from arraycontext.impl.pytato.utils import (_normalize_pt_expr,
-                                                    get_cl_axes_from_pt_axes)
-        from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array,
-                                                                  TaggableCLArray)
+        from arraycontext.impl.pyopencl.taggable_cl_array import (
+            TaggableCLArray, to_tagged_cl_array)
         from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
+        from arraycontext.impl.pytato.utils import (
+            _normalize_pt_expr, get_cl_axes_from_pt_axes)
 
         array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {}
         key_to_frozen_subary: Dict[str, TaggableCLArray] = {}
@@ -549,8 +554,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
     def thaw(self, array):
         import pytato as pt
-        from .utils import get_pt_axes_from_cl_axes
+
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
+        from .utils import get_pt_axes_from_cl_axes
 
         def _thaw(ary):
             return pt.make_data_wrapper(ary.with_queue(self.queue),
@@ -579,8 +585,9 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
     def call_loopy(self, program, **kwargs):
         import pytato as pt
-        from pytato.scalar_expr import SCALAR_CLASSES
         from pytato.loopy import call_loopy
+        from pytato.scalar_expr import SCALAR_CLASSES
+
         from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
 
         entrypoint = program.default_entrypoint.name
@@ -617,6 +624,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         import pytato as pt
+
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
         if arg_names is None:
@@ -685,8 +693,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
             representation. This interface should be considered
             unstable.
         """
-        import pytato as pt
         import jax.numpy as jnp
+
+        import pytato as pt
         super().__init__(compile_trace_callback=compile_trace_callback)
         self.array_types = (pt.Array, jnp.ndarray)
 
@@ -731,6 +740,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
 
     def from_numpy(self, array):
         import jax
+
         import pytato as pt
 
         def _from_numpy(ary):
@@ -754,9 +764,10 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         if np.isscalar(array):
             return array
 
+        import jax.numpy as jnp
+
         import pytato as pt
 
-        import jax.numpy as jnp
         from arraycontext.container.traversal import rec_keyed_map_array_container
         from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
 
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 6fb56cc..d7e0416 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -29,26 +29,26 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext.context import ArrayT
-from arraycontext.container import ArrayContainer, is_array_container_type
-from arraycontext.impl.pytato import (_BasePytatoArrayContext,
-                                      PytatoJAXArrayContext,
-                                      PytatoPyOpenCLArrayContext)
-from arraycontext.container.traversal import rec_keyed_map_array_container
-
 import abc
-import numpy as np
-from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet, Type
+import itertools
+import logging
 from dataclasses import dataclass, field
-from pyrsistent import pmap, PMap
+from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type
+
+import numpy as np
+from pyrsistent import PMap, pmap
 
 import pytato as pt
-import itertools
+from pytools import ProcessLogger
 from pytools.tag import Tag
 
-from pytools import ProcessLogger
+from arraycontext.container import ArrayContainer, is_array_container_type
+from arraycontext.container.traversal import rec_keyed_map_array_container
+from arraycontext.context import ArrayT
+from arraycontext.impl.pytato import (
+    PytatoJAXArrayContext, PytatoPyOpenCLArrayContext, _BasePytatoArrayContext)
+
 
-import logging
 logger = logging.getLogger(__name__)
 
 
@@ -185,8 +185,9 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
       :meth:`PytatoPyOpenCLArrayContext.transform_dag`.
     """
     import pyopencl.array as cla
-    from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array,
-                                                              TaggableCLArray)
+
+    from arraycontext.impl.pyopencl.taggable_cl_array import (
+        TaggableCLArray, to_tagged_cl_array)
     if isinstance(ary, pt.Array):
         dag = pt.make_dict_of_named_arrays({"_actx_out": ary})
         # Transform the DAG to give metadata inference a chance to do its job
@@ -390,9 +391,8 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
         if prg_id is None:
             prg_id = self.f
 
-        from pytato.target.loopy import BoundPyOpenCLProgram
-
         import loopy as lp
+        from pytato.target.loopy import BoundPyOpenCLProgram
 
         self.actx._compile_trace_callback(
                 prg_id, "pre_transform_dag", dict_of_named_arrays)
@@ -632,8 +632,8 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
     output_template: ArrayContainer
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
-        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
         from .utils import get_cl_axes_from_pt_axes
+        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
 
         input_kwargs_for_loopy = _args_to_device_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
@@ -674,8 +674,8 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
     output_name: str
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
-        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
         from .utils import get_cl_axes_from_pt_axes
+        from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array
 
         input_kwargs_for_loopy = _args_to_device_buffers(
                 self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py
index e17f8ee..4dad159 100644
--- a/arraycontext/impl/pytato/fake_numpy.py
+++ b/arraycontext/impl/pytato/fake_numpy.py
@@ -25,19 +25,14 @@ from functools import partial, reduce
 
 import numpy as np
 
-from arraycontext.fake_numpy import (
-        BaseFakeNumpyLinalgNamespace
-        )
-from arraycontext.loopy import (
-        LoopyBasedFakeNumpyNamespace
-        )
+import pytato as pt
+
 from arraycontext.container import NotAnArrayContainerError, serialize_container
 from arraycontext.container.traversal import (
-        rec_map_array_container,
-        rec_multimap_array_container,
-        rec_map_reduce_array_container,
-        )
-import pytato as pt
+    rec_map_array_container, rec_map_reduce_array_container,
+    rec_multimap_array_container)
+from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
+from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
 
 
 class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py
index 9d77202..e0af81c 100644
--- a/arraycontext/impl/pytato/utils.py
+++ b/arraycontext/impl/pytato/utils.py
@@ -23,15 +23,17 @@ THE SOFTWARE.
 """
 
 
-from typing import Any, Dict, Set, Tuple, Mapping, Optional, TYPE_CHECKING
-from pytools import memoize_method
+from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Set, Tuple
 
-from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis
-from pytato.array import Array, DataWrapper, DictOfNamedArrays
+from pytato.array import (
+    Array, Axis as PtAxis, DataWrapper, DictOfNamedArrays, Placeholder, SizeParam,
+    make_placeholder)
+from pytato.target.loopy import LoopyPyOpenCLTarget
 from pytato.transform import CopyMapper
-from pytools import UniqueNameGenerator
+from pytools import UniqueNameGenerator, memoize_method
+
 from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
-from pytato.target.loopy import LoopyPyOpenCLTarget
+
 
 if TYPE_CHECKING:
     import loopy as lp
diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py
index 1f90318..f8e54b5 100644
--- a/arraycontext/loopy.py
+++ b/arraycontext/loopy.py
@@ -28,12 +28,15 @@ THE SOFTWARE.
 """
 
 import numpy as np
+
 import loopy as lp
 from loopy.version import MOST_RECENT_LANGUAGE_VERSION
-from arraycontext.fake_numpy import BaseFakeNumpyNamespace
-from arraycontext.container.traversal import multimapped_over_array_containers
 from pytools import memoize_in
 
+from arraycontext.container.traversal import multimapped_over_array_containers
+from arraycontext.fake_numpy import BaseFakeNumpyNamespace
+
+
 # {{{ loopy
 
 _DEFAULT_LOOPY_OPTIONS = lp.Options(
@@ -89,6 +92,7 @@ def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
         domain_bset, = domain.get_basic_sets()
 
         import loopy as lp
+
         from .loopy import make_loopy_program
         from arraycontext.transform_metadata import ElementwiseMapKernelTag
         return make_loopy_program(
diff --git a/arraycontext/metadata.py b/arraycontext/metadata.py
index 39934d6..95fc639 100644
--- a/arraycontext/metadata.py
+++ b/arraycontext/metadata.py
@@ -29,9 +29,10 @@ THE SOFTWARE.
 
 import sys
 from dataclasses import dataclass
-from pytools.tag import Tag, UniqueTag
 from warnings import warn
 
+from pytools.tag import Tag, UniqueTag
+
 
 @dataclass(frozen=True)
 class NameHint(UniqueTag):
@@ -52,8 +53,8 @@ class NameHint(UniqueTag):
 # {{{ deprecation handling
 
 try:
-    from meshmode.transform_metadata import FirstAxisIsElementsTag \
-            as _FirstAxisIsElementsTag
+    from meshmode.transform_metadata import (
+        FirstAxisIsElementsTag as _FirstAxisIsElementsTag)
 except ImportError:
     # placeholder in case meshmode is too old to have it.
     class _FirstAxisIsElementsTag(Tag):  # type: ignore[no-redef]
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index f84915b..4fce588 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -63,7 +63,7 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
     @classmethod
     def is_available(cls) -> bool:
         try:
-            import pyopencl         # noqa: F401
+            import pyopencl  # noqa: F401
             return True
         except ImportError:
             return False
@@ -79,6 +79,7 @@ class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
         collect()
 
         import pyopencl as cl
+
         # On Intel CPU CL, existence of a command queue does not ensure that
         # the context survives.
         ctx = cl.Context([self.device])
@@ -134,8 +135,8 @@ class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory
     @classmethod
     def is_available(cls) -> bool:
         try:
-            import pyopencl         # noqa: F401
-            import pytato           # noqa: F401
+            import pyopencl  # noqa: F401
+            import pytato  # noqa: F401
             return True
         except ImportError:
             return False
@@ -182,14 +183,15 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
     @classmethod
     def is_available(cls) -> bool:
         try:
-            import jax              # noqa: F401
+            import jax  # noqa: F401
             return True
         except ImportError:
             return False
 
     def __call__(self):
-        from arraycontext import EagerJAXArrayContext
         from jax.config import config
+
+        from arraycontext import EagerJAXArrayContext
         config.update("jax_enable_x64", True)
         return EagerJAXArrayContext()
 
@@ -204,15 +206,17 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory):
     @classmethod
     def is_available(cls) -> bool:
         try:
-            import jax              # noqa: F401
-            import pytato           # noqa: F401
+            import jax  # noqa: F401
+
+            import pytato  # noqa: F401
             return True
         except ImportError:
             return False
 
     def __call__(self):
-        from arraycontext import PytatoJAXArrayContext
         from jax.config import config
+
+        from arraycontext import PytatoJAXArrayContext
         config.update("jax_enable_x64", True)
         return PytatoJAXArrayContext()
 
diff --git a/setup.cfg b/setup.cfg
index b24271f..60eab3b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,14 @@ docstring-quotes = """
 multiline-quotes = """
 # enable-flake8-bugbear
 
+[isort]
+known_firstparty=pytools,pyopencl,pymbolic,islpy,loopy,pytato
+known_local_folder=arraycontext
+line_length = 85
+lines_after_imports = 2
+combine_as_imports = True
+multi_line_output = 4
+
 [mypy]
 # it reads pytato code, and pytato is 3.8+
 python_version = 3.8
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index f373b77..ebc8dc4 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -20,6 +20,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+import logging
 from dataclasses import dataclass
 from typing import Union
 
@@ -28,26 +29,16 @@ import pytest
 
 from pytools.obj_array import make_obj_array
 
-from arraycontext import (
-        ArrayContext,
-        dataclass_array_container, with_container_arithmetic,
-        serialize_container, deserialize_container, with_array_context,
-        FirstAxisIsElementsTag,
-        PyOpenCLArrayContext,
-        PytatoPyOpenCLArrayContext,
-        EagerJAXArrayContext,
-        ArrayContainer,
-        tag_axes)
 from arraycontext import (  # noqa: F401
-        pytest_generate_tests_for_array_contexts,
-        )
-from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass,
-                                 _PytestPytatoPyOpenCLArrayContextFactory,
-                                 _PytestEagerJaxArrayContextFactory,
-                                 _PytestPytatoJaxArrayContextFactory)
+    ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag,
+    PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container,
+    deserialize_container, pytest_generate_tests_for_array_contexts,
+    serialize_container, tag_axes, with_array_context, with_container_arithmetic)
+from arraycontext.pytest import (
+    _PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass,
+    _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory)
 
 
-import logging
 logger = logging.getLogger(__name__)
 
 
@@ -497,8 +488,9 @@ def test_dof_array_arithmetic_same_as_numpy(actx_factory):
         return ary.imag
 
     import operator
+    from random import randrange, uniform
+
     from pytools import generate_nonnegative_integer_tuples_below as gnitb
-    from random import uniform, randrange
     for op_func, n_args, use_integers in [
             (operator.add, 2, False),
             (operator.sub, 2, False),
@@ -775,9 +767,8 @@ def test_container_map_on_device_scalar(actx_factory):
     arys += (np.pi,)
 
     from arraycontext import (
-            map_array_container, rec_map_array_container,
-            map_reduce_array_container, rec_map_reduce_array_container,
-            )
+        map_array_container, map_reduce_array_container, rec_map_array_container,
+        rec_map_reduce_array_container)
 
     for size, ary in zip(expected_sizes, arys[:-1]):
         result = map_array_container(lambda x: x, ary)
@@ -921,6 +912,7 @@ def test_container_arithmetic(actx_factory):
         assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
 
     from functools import partial
+
     from arraycontext import rec_multimap_array_container
     for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
         rec_multimap_array_container(
@@ -973,8 +965,7 @@ def test_container_freeze_thaw(actx_factory):
     # {{{ check
 
     from arraycontext import (
-            get_container_context_opt,
-            get_container_context_recursively_opt)
+        get_container_context_opt, get_container_context_recursively_opt)
 
     assert get_container_context_opt(ary_of_dofs) is None
     assert get_container_context_opt(mat_of_dofs) is None
@@ -1077,7 +1068,7 @@ def test_flatten_array_container(actx_factory, shapes):
 
 
 def _checked_flatten(ary, actx, leaf_class=None):
-    from arraycontext import flatten, flat_size_and_dtype
+    from arraycontext import flat_size_and_dtype, flatten
     result = flatten(ary, actx, leaf_class=leaf_class)
 
     if leaf_class is None:
@@ -1279,7 +1270,7 @@ def test_actx_compile_kwargs(actx_factory):
 def test_actx_compile_with_tuple_output_keys(actx_factory):
     # arraycontext.git<=3c9aee68 would fail due to a bug in output
     # key stringification logic.
-    from arraycontext import (to_numpy, from_numpy)
+    from arraycontext import from_numpy, to_numpy
     actx = actx_factory()
 
     def my_rhs(scale, vel):
@@ -1361,6 +1352,7 @@ def test_leaf_array_type_broadcasting(actx_factory):
             return True
         else:
             import pyopencl as cl
+
             # See https://github.com/inducer/pyopencl/issues/498
             return cl.version.VERSION > (2021, 2, 5)
 
@@ -1573,6 +1565,7 @@ def test_tagging(actx_factory):
 
 def test_compile_anonymous_function(actx_factory):
     from functools import partial
+
     # See https://github.com/inducer/grudge/issues/287
     actx = actx_factory()
     f = actx.compile(lambda x: 2*x+40)
diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py
index fd3c6ea..eea1144 100644
--- a/test/test_pytato_arraycontext.py
+++ b/test/test_pytato_arraycontext.py
@@ -22,13 +22,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext import PytatoPyOpenCLArrayContext
-from arraycontext import pytest_generate_tests_for_array_contexts
-from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory
-from pytools.tag import Tag
+import logging
 
 import pytest
-import logging
+
+from pytools.tag import Tag
+
+from arraycontext import (
+    PytatoPyOpenCLArrayContext, pytest_generate_tests_for_array_contexts)
+from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory
+
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/test/test_utils.py b/test/test_utils.py
index 5edf66b..94f6b0a 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -22,11 +22,12 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
-import pytest
+import logging
 
 import numpy as np
+import pytest
+
 
-import logging
 logger = logging.getLogger(__name__)
 
 
@@ -47,8 +48,9 @@ def test_pt_actx_key_stringification_uniqueness():
 # {{{ test_dataclass_array_container
 
 def test_dataclass_array_container() -> None:
-    from typing import Optional
     from dataclasses import dataclass, field
+    from typing import Optional
+
     from arraycontext import dataclass_array_container
 
     # {{{ string fields
@@ -111,10 +113,9 @@ def test_dataclass_array_container() -> None:
 
 def test_dataclass_container_unions() -> None:
     from dataclasses import dataclass
-    from arraycontext import dataclass_array_container
-
     from typing import Union
-    from arraycontext import Array
+
+    from arraycontext import Array, dataclass_array_container
 
     # {{{ union fields
 
-- 
GitLab