From 29d74437afba7b6cc27f58adb06a43c804603528 Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Tue, 5 Nov 2024 09:44:37 +0200
Subject: [PATCH] ruff: fix type import errors

---
 arraycontext/container/__init__.py            | 21 ++----
 arraycontext/container/arithmetic.py          | 39 +++++-----
 arraycontext/container/traversal.py           | 61 ++++++++--------
 arraycontext/context.py                       | 28 +++-----
 arraycontext/impl/jax/__init__.py             |  6 +-
 arraycontext/impl/numpy/fake_numpy.py         |  7 +-
 arraycontext/impl/pyopencl/__init__.py        | 17 ++---
 .../impl/pyopencl/taggable_cl_array.py        | 28 ++++----
 arraycontext/impl/pytato/__init__.py          | 72 +++++++++----------
 arraycontext/impl/pytato/compile.py           | 59 +++++++--------
 arraycontext/impl/pytato/utils.py             | 14 ++--
 arraycontext/loopy.py                         |  3 +-
 arraycontext/pytest.py                        | 22 +++---
 arraycontext/version.py                       |  3 +-
 test/test_arraycontext.py                     |  6 +-
 15 files changed, 179 insertions(+), 207 deletions(-)

diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py
index 655a3e6..6c4fb67 100644
--- a/arraycontext/container/__init__.py
+++ b/arraycontext/container/__init__.py
@@ -79,18 +79,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from collections.abc import Hashable, Sequence
 from functools import singledispatch
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Hashable,
-    Optional,
-    Protocol,
-    Sequence,
-    Tuple,
-    TypeAlias,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar
 
 # For use in singledispatch type annotations, because sphinx can't figure out
 # what 'np' is.
@@ -162,7 +153,7 @@ class NotAnArrayContainerError(TypeError):
 
 
 SerializationKey: TypeAlias = Hashable
-SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]]
+SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
 
 
 @singledispatch
@@ -249,7 +240,7 @@ def is_array_container(ary: Any) -> bool:
 
 
 @singledispatch
-def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:
+def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
     """Retrieves the :class:`ArrayContext` from the container, if any.
 
     This function is not recursive, so it will only search at the root level
@@ -303,7 +294,7 @@ def _deserialize_ndarray_container(  # type: ignore[misc]
 # {{{ get_container_context_recursively
 
 def get_container_context_recursively_opt(
-        ary: ArrayContainer) -> Optional[ArrayContext]:
+        ary: ArrayContainer) -> ArrayContext | None:
     """Walks the :class:`ArrayContainer` hierarchy to find an
     :class:`ArrayContext` associated with it.
 
@@ -337,7 +328,7 @@ def get_container_context_recursively_opt(
         return actx
 
 
-def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]:
+def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
     """Walks the :class:`ArrayContainer` hierarchy to find an
     :class:`ArrayContext` associated with it.
 
diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py
index 9366b26..72e39a1 100644
--- a/arraycontext/container/arithmetic.py
+++ b/arraycontext/container/arithmetic.py
@@ -34,7 +34,8 @@ THE SOFTWARE.
 """
 
 import enum
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union
+from collections.abc import Callable
+from typing import Any, TypeVar
 from warnings import warn
 
 import numpy as np
@@ -90,7 +91,7 @@ _BINARY_OP_AND_DUNDER = [
         ]
 
 
-def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
+def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str:
     if isinstance(arg1, tuple):
         arg1_entry, arg1_container = arg1
         return (f"{op_str.format(arg1_entry)} "
@@ -100,20 +101,14 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
 
 
 def _format_binary_op_str(op_str: str,
-        arg1: Union[Tuple[str, str], str],
-        arg2: Union[Tuple[str, str], str]) -> str:
+        arg1: tuple[str, str] | str,
+        arg2: tuple[str, str] | str) -> str:
     if isinstance(arg1, tuple) and isinstance(arg2, tuple):
-        import sys
-        if sys.version_info >= (3, 10):
-            strict_arg = ", strict=__debug__"
-        else:
-            strict_arg = ""
-
         arg1_entry, arg1_container = arg1
         arg2_entry, arg2_container = arg2
         return (f"{op_str.format(arg1_entry, arg2_entry)} "
                 f"for {arg1_entry}, {arg2_entry} "
-                f"in zip({arg1_container}, {arg2_container}{strict_arg})")
+                f"in zip({arg1_container}, {arg2_container}, strict=__debug__)")
 
     elif isinstance(arg1, tuple):
         arg1_entry, arg1_container = arg1
@@ -160,23 +155,23 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
 
 def with_container_arithmetic(
             *,
-            number_bcasts_across: Optional[bool] = None,
-            bcasts_across_obj_array: Optional[bool] = None,
-            container_types_bcast_across: Optional[Tuple[type, ...]] = None,
+            number_bcasts_across: bool | None = None,
+            bcasts_across_obj_array: bool | None = None,
+            container_types_bcast_across: tuple[type, ...] | None = None,
             arithmetic: bool = True,
             matmul: bool = False,
             bitwise: bool = False,
             shift: bool = False,
-            _cls_has_array_context_attr: Optional[bool] = None,
-            eq_comparison: Optional[bool] = None,
-            rel_comparison: Optional[bool] = None,
+            _cls_has_array_context_attr: bool | None = None,
+            eq_comparison: bool | None = None,
+            rel_comparison: bool | None = None,
 
             # deprecated:
-            bcast_number: Optional[bool] = None,
-            bcast_obj_array: Optional[bool] = None,
+            bcast_number: bool | None = None,
+            bcast_obj_array: bool | None = None,
             bcast_numpy_array: bool = False,
-            _bcast_actx_array_type: Optional[bool] = None,
-            bcast_container_types: Optional[Tuple[type, ...]] = None,
+            _bcast_actx_array_type: bool | None = None,
+            bcast_container_types: tuple[type, ...] | None = None,
         ) -> Callable[[type], type]:
     """A class decorator that implements built-in operators for array containers
     by propagating the operations to the elements of the container.
@@ -482,7 +477,7 @@ def with_container_arithmetic(
             assert k1 == k2
             return k1
 
-        def tup_str(t: Tuple[str, ...]) -> str:
+        def tup_str(t: tuple[str, ...]) -> str:
             if not t:
                 return "()"
             else:
diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 5f94ad6..80f38af 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -70,8 +70,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from collections.abc import Callable, Iterable
 from functools import partial, singledispatch, update_wrapper
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast
+from typing import Any, cast
 from warnings import warn
 
 import numpy as np
@@ -100,7 +101,7 @@ from arraycontext.context import (
 def _map_array_container_impl(
         f: Callable[[ArrayOrContainer], ArrayOrContainer],
         ary: ArrayOrContainer, *,
-        leaf_cls: Optional[type] = None,
+        leaf_cls: type | None = None,
         recursive: bool = False) -> ArrayOrContainer:
     """Helper for :func:`rec_map_array_container`.
 
@@ -129,9 +130,9 @@ def _map_array_container_impl(
 def _multimap_array_container_impl(
         f: Callable[..., Any],
         *args: Any,
-        reduce_func: Optional[Callable[
-            [ArrayContainer, Iterable[Tuple[Any, Any]]], Any]] = None,
-        leaf_cls: Optional[type] = None,
+        reduce_func: (
+            Callable[[ArrayContainer, Iterable[tuple[Any, Any]]], Any] | None) = None,
+        leaf_cls: type | None = None,
         recursive: bool = False) -> ArrayOrContainer:
     """Helper for :func:`rec_multimap_array_container`.
 
@@ -183,7 +184,7 @@ def _multimap_array_container_impl(
 
     # {{{ find all containers in the argument list
 
-    container_indices: List[int] = []
+    container_indices: list[int] = []
 
     for i, arg in enumerate(args):
         if type(arg) is leaf_cls:
@@ -244,7 +245,7 @@ def stringify_array_container_tree(ary: ArrayOrContainer) -> str:
     :returns: a string for an ASCII tree representation of the array container,
         similar to `asciitree <https://github.com/mbr/asciitree>`__.
     """
-    def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None:
+    def rec(lines: list[str], ary_: ArrayOrContainerT, level: int) -> None:
         try:
             iterable = serialize_container(ary_)
         except NotAnArrayContainerError:
@@ -307,7 +308,7 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
 def rec_map_array_container(
         f: Callable[[Any], Any],
         ary: ArrayOrContainer,
-        leaf_class: Optional[type] = None) -> ArrayOrContainer:
+        leaf_class: type | None = None) -> ArrayOrContainer:
     r"""Applies *f* recursively to an :class:`ArrayContainer`.
 
     For a non-recursive version see :func:`map_array_container`.
@@ -319,12 +320,12 @@ def rec_map_array_container(
 
 
 def mapped_over_array_containers(
-        f: Optional[Callable[[ArrayOrContainer], ArrayOrContainer]] = None,
-        leaf_class: Optional[type] = None) -> Union[
-            Callable[[ArrayOrContainer], ArrayOrContainer],
-            Callable[
+        f: Callable[[ArrayOrContainer], ArrayOrContainer] | None = None,
+        leaf_class: type | None = None) -> (
+            Callable[[ArrayOrContainer], ArrayOrContainer]
+            | Callable[
                 [Callable[[Any], Any]],
-                Callable[[ArrayOrContainer], ArrayOrContainer]]]:
+                Callable[[ArrayOrContainer], ArrayOrContainer]]):
     """Decorator around :func:`rec_map_array_container`."""
     def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[
             [ArrayOrContainer], ArrayOrContainer]:
@@ -340,7 +341,7 @@ def mapped_over_array_containers(
 def rec_multimap_array_container(
         f: Callable[..., Any],
         *args: Any,
-        leaf_class: Optional[type] = None) -> Any:
+        leaf_class: type | None = None) -> Any:
     r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
 
     For a non-recursive version see :func:`multimap_array_container`.
@@ -353,10 +354,10 @@ def rec_multimap_array_container(
 
 
 def multimapped_over_array_containers(
-        f: Optional[Callable[..., Any]] = None,
-        leaf_class: Optional[type] = None) -> Union[
-            Callable[..., Any],
-            Callable[[Callable[..., Any]], Callable[..., Any]]]:
+        f: Callable[..., Any] | None = None,
+        leaf_class: type | None = None) -> (
+            Callable[..., Any]
+            | Callable[[Callable[..., Any]], Callable[..., Any]]):
     """Decorator around :func:`rec_multimap_array_container`."""
     def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
         # can't use functools.partial, because its result is insufficiently
@@ -403,7 +404,7 @@ def keyed_map_array_container(
 
 
 def rec_keyed_map_array_container(
-        f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT],
+        f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
         ary: ArrayOrContainer) -> ArrayOrContainer:
     """
     Works similarly to :func:`rec_map_array_container`, except that *f* also
@@ -412,7 +413,7 @@ def rec_keyed_map_array_container(
     the current array.
     """
 
-    def rec(keys: Tuple[SerializationKey, ...],
+    def rec(keys: tuple[SerializationKey, ...],
             _ary: ArrayOrContainerT) -> ArrayOrContainerT:
         try:
             iterable = serialize_container(_ary)
@@ -469,7 +470,7 @@ def multimap_reduce_array_container(
     # NOTE: this wrapper matches the signature of `deserialize_container`
     # to make plugging into `_multimap_array_container_impl` easier
     def _reduce_wrapper(
-            ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]
+            ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]
             ) -> Array:
         return reduce_func([subary for _, subary in iterable])
 
@@ -482,7 +483,7 @@ def rec_map_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[[Any], Any],
         ary: ArrayOrContainer,
-        leaf_class: Optional[type] = None) -> ArrayOrContainer:
+        leaf_class: type | None = None) -> ArrayOrContainer:
     """Perform a map-reduce over array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of *ary*
@@ -540,7 +541,7 @@ def rec_multimap_reduce_array_container(
         reduce_func: Callable[[Iterable[Any]], Any],
         map_func: Callable[..., Any],
         *args: Any,
-        leaf_class: Optional[type] = None) -> ArrayOrContainer:
+        leaf_class: type | None = None) -> ArrayOrContainer:
     r"""Perform a map-reduce over multiple array containers recursively.
 
     :param reduce_func: callable used to reduce over the components of any
@@ -559,7 +560,7 @@ def rec_multimap_reduce_array_container(
     # NOTE: this wrapper matches the signature of `deserialize_container`
     # to make plugging into `_multimap_array_container_impl` easier
     def _reduce_wrapper(
-            ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]) -> Any:
+            ary: ArrayContainer, iterable: Iterable[tuple[Any, Any]]) -> Any:
         return reduce_func([subary for _, subary in iterable])
 
     return _multimap_array_container_impl(
@@ -573,7 +574,7 @@ def rec_multimap_reduce_array_container(
 
 def freeze(
         ary: ArrayOrContainerT,
-        actx: Optional[ArrayContext] = None) -> ArrayOrContainerT:
+        actx: ArrayContext | None = None) -> ArrayOrContainerT:
     r"""Freezes recursively by going through all components of the
     :class:`ArrayContainer` *ary*.
 
@@ -650,7 +651,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
 
 @singledispatch
 def with_array_context(ary: ArrayOrContainerT,
-                       actx: Optional[ArrayContext]) -> ArrayOrContainerT:
+                       actx: ArrayContext | None) -> ArrayOrContainerT:
     """
     Recursively associates *actx* to all the components of *ary*.
 
@@ -674,7 +675,7 @@ def with_array_context(ary: ArrayOrContainerT,
 
 def flatten(
         ary: ArrayOrContainer, actx: ArrayContext, *,
-        leaf_class: Optional[type] = None,
+        leaf_class: type | None = None,
         ) -> Any:
     """Convert all arrays in the :class:`~arraycontext.ArrayContainer`
     into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.
@@ -696,7 +697,7 @@ def flatten(
     """
     common_dtype = None
 
-    def _flatten(subary: ArrayOrContainer) -> List[Array]:
+    def _flatten(subary: ArrayOrContainer) -> list[Array]:
         nonlocal common_dtype
 
         try:
@@ -874,7 +875,7 @@ def unflatten(
 
 
 def flat_size_and_dtype(
-        ary: ArrayOrContainer) -> Tuple[int, Optional[np.dtype[Any]]]:
+        ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
     """
     :returns: a tuple ``(size, dtype)`` that would be the length and
         :class:`numpy.dtype` of the one-dimensional array returned by
@@ -910,7 +911,7 @@ def flat_size_and_dtype(
 # {{{ numpy conversion
 
 def from_numpy(
-        ary: Union[np.ndarray, ScalarLike],
+        ary: np.ndarray | ScalarLike,
         actx: ArrayContext) -> ArrayOrContainerOrScalar:
     """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
     to the base array type of :class:`~arraycontext.ArrayContext`.
diff --git a/arraycontext/context.py b/arraycontext/context.py
index ee989ef..398f8aa 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -159,18 +159,8 @@ THE SOFTWARE.
 """
 
 from abc import ABC, abstractmethod
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Mapping,
-    Optional,
-    Protocol,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from collections.abc import Callable, Mapping
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union
 from warnings import warn
 
 import numpy as np
@@ -187,7 +177,7 @@ if TYPE_CHECKING:
 
 # {{{ typing
 
-ScalarLike = Union[int, float, complex, np.generic]
+ScalarLike = int | float | complex | np.generic
 
 SelfType = TypeVar("SelfType")
 
@@ -206,7 +196,7 @@ class Array(Protocol):
     """
 
     @property
-    def shape(self) -> Tuple[int, ...]:
+    def shape(self) -> tuple[int, ...]:
         ...
 
     @property
@@ -291,7 +281,7 @@ class ArrayContext(ABC):
     .. automethod:: compile
     """
 
-    array_types: Tuple[type, ...] = ()
+    array_types: tuple[type, ...] = ()
 
     def __init__(self) -> None:
         self.np = self._get_fake_numpy_namespace()
@@ -304,7 +294,7 @@ class ArrayContext(ABC):
         raise TypeError(f"unhashable type: '{type(self).__name__}'")
 
     def zeros(self,
-              shape: Union[int, Tuple[int, ...]],
+              shape: int | tuple[int, ...],
               dtype: "np.dtype[Any]") -> Array:
         warn(f"{type(self).__name__}.zeros is deprecated and will stop "
             "working in 2025. Use actx.np.zeros instead.",
@@ -340,7 +330,7 @@ class ArrayContext(ABC):
     @abstractmethod
     def call_loopy(self,
                    t_unit: "loopy.TranslationUnit",
-                   **kwargs: Any) -> Dict[str, Array]:
+                   **kwargs: Any) -> dict[str, Array]:
         """Execute the :mod:`loopy` program *program* on the arguments
         *kwargs*.
 
@@ -423,7 +413,7 @@ class ArrayContext(ABC):
 
     @memoize_method
     def _get_einsum_prg(self,
-                        spec: str, arg_names: Tuple[str, ...],
+                        spec: str, arg_names: tuple[str, ...],
                         tagged: ToTagSetConvertible) -> "loopy.TranslationUnit":
         import loopy as lp
         from loopy.version import MOST_RECENT_LANGUAGE_VERSION
@@ -454,7 +444,7 @@ class ArrayContext(ABC):
     # [1] https://github.com/inducer/meshmode/issues/177
     def einsum(self,
                spec: str, *args: Array,
-               arg_names: Optional[Tuple[str, ...]] = None,
+               arg_names: tuple[str, ...] | None = None,
                tagged: ToTagSetConvertible = ()) -> Array:
         """Computes the result of Einstein summation following the
         convention in :func:`numpy.einsum`.
diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py
index 26cb9db..0b6cd72 100644
--- a/arraycontext/impl/jax/__init__.py
+++ b/arraycontext/impl/jax/__init__.py
@@ -27,7 +27,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Callable, Optional, Tuple
+from collections.abc import Callable
 
 import numpy as np
 
@@ -63,8 +63,8 @@ class EagerJAXArrayContext(ArrayContext):
 
     def _rec_map_container(
             self, func: Callable[[Array], Array], array: ArrayOrContainer,
-            allowed_types: Optional[Tuple[type, ...]] = None, *,
-            default_scalar: Optional[ScalarLike] = None,
+            allowed_types: tuple[type, ...] | None = None, *,
+            default_scalar: ScalarLike | None = None,
             strict: bool = False) -> ArrayOrContainer:
         if allowed_types is None:
             allowed_types = self.array_types
diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py
index b305717..8517ab6 100644
--- a/arraycontext/impl/numpy/fake_numpy.py
+++ b/arraycontext/impl/numpy/fake_numpy.py
@@ -21,7 +21,9 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
+
 from functools import partial, reduce
+from typing import cast
 
 import numpy as np
 
@@ -143,10 +145,9 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace):
         else:
             if len(serialized_x) != len(serialized_y):
                 return false_ary
-            return reduce(
-                    np.logical_and,
+            return np.logical_and.reduce(
                     [(true_ary if kx_i == ky_i else false_ary)
-                        and self.array_equal(x_i, y_i)
+                        and cast(np.ndarray, self.array_equal(x_i, y_i))
                         for (kx_i, x_i), (ky_i, y_i)
                         in zip(serialized_x, serialized_y)],
                     true_ary)
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 60c001a..84d5f48 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -31,7 +31,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+from collections.abc import Callable
+from typing import TYPE_CHECKING
 from warnings import warn
 
 import numpy as np
@@ -82,9 +83,9 @@ class PyOpenCLArrayContext(ArrayContext):
 
     def __init__(self,
             queue: pyopencl.CommandQueue,
-            allocator: Optional[pyopencl.tools.AllocatorBase] = None,
-            wait_event_queue_length: Optional[int] = None,
-            force_device_scalars: Optional[bool] = None) -> None:
+            allocator: pyopencl.tools.AllocatorBase | None = None,
+            wait_event_queue_length: int | None = None,
+            force_device_scalars: bool | None = None) -> None:
         r"""
         :arg wait_event_queue_length: The length of a queue of
             :class:`~pyopencl.Event` objects that are maintained by the
@@ -132,7 +133,7 @@ class PyOpenCLArrayContext(ArrayContext):
         self._passed_force_device_scalars = force_device_scalars is not None
 
         self._wait_event_queue_length = wait_event_queue_length
-        self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {}
+        self._kernel_name_to_wait_event_queue: dict[str, list[cl.Event]] = {}
 
         if queue.device.type & cl.device_type.GPU:
             if allocator is None:
@@ -150,7 +151,7 @@ class PyOpenCLArrayContext(ArrayContext):
                         stacklevel=2)
 
         self._loopy_transform_cache: \
-                Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
+                dict[lp.TranslationUnit, lp.TranslationUnit] = {}
 
         # TODO: Ideally this should only be `(TaggableCLArray,)`, but
         # that would break the logic in the downstream users.
@@ -162,8 +163,8 @@ class PyOpenCLArrayContext(ArrayContext):
 
     def _rec_map_container(
             self, func: Callable[[Array], Array], array: ArrayOrContainer,
-            allowed_types: Optional[Tuple[type, ...]] = None, *,
-            default_scalar: Optional[ScalarLike] = None,
+            allowed_types: tuple[type, ...] | None = None, *,
+            default_scalar: ScalarLike | None = None,
             strict: bool = False) -> ArrayOrContainer:
         import arraycontext.impl.pyopencl.taggable_cl_array as tga
 
diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py
index a0f3ef4..7de7611 100644
--- a/arraycontext/impl/pyopencl/taggable_cl_array.py
+++ b/arraycontext/impl/pyopencl/taggable_cl_array.py
@@ -6,7 +6,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Any, Dict, FrozenSet, Optional, Tuple
+from typing import Any
 
 import numpy as np
 
@@ -23,19 +23,19 @@ class Axis(Taggable):
     Records the tags corresponding to a dimension of :class:`TaggableCLArray`.
     """
 
-    tags: FrozenSet[Tag]
+    tags: frozenset[Tag]
 
-    def _with_new_tags(self, tags: FrozenSet[Tag]) -> "Axis":
+    def _with_new_tags(self, tags: frozenset[Tag]) -> "Axis":
         from dataclasses import replace
         return replace(self, tags=tags)
 
 
 @memoize
-def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]:
+def _construct_untagged_axes(ndim: int) -> tuple[Axis, ...]:
     return tuple(Axis(frozenset()) for _ in range(ndim))
 
 
-def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]:
+def _unwrap_cl_array(ary: cla.Array) -> dict[str, Any]:
     return {
         "shape": ary.shape,
         "dtype": ary.dtype,
@@ -109,7 +109,7 @@ class TaggableCLArray(cla.Array, Taggable):
         return type(self)(None, tags=self.tags, axes=self.axes,
                           **_unwrap_cl_array(ary))
 
-    def _with_new_tags(self, tags: FrozenSet[Tag]) -> "TaggableCLArray":
+    def _with_new_tags(self, tags: frozenset[Tag]) -> "TaggableCLArray":
         return type(self)(None, tags=tags, axes=self.axes,
                           **_unwrap_cl_array(self))
 
@@ -127,8 +127,8 @@ class TaggableCLArray(cla.Array, Taggable):
 
 
 def to_tagged_cl_array(ary: cla.Array,
-                       axes: Optional[Tuple[Axis, ...]] = None,
-                       tags: FrozenSet[Tag] = frozenset()) -> TaggableCLArray:
+                       axes: tuple[Axis, ...] | None = None,
+                       tags: frozenset[Tag] = frozenset()) -> TaggableCLArray:
     """
     Returns a :class:`TaggableCLArray` that is constructed from the data in
     *ary* along with the metadata from *axes* and *tags*. If *ary* is already a
@@ -167,8 +167,8 @@ def to_tagged_cl_array(ary: cla.Array,
 # {{{ creation
 
 def empty(queue, shape, dtype=float, *,
-        axes: Optional[Tuple[Axis, ...]] = None,
-        tags: FrozenSet[Tag] = frozenset(),
+        axes: tuple[Axis, ...] | None = None,
+        tags: frozenset[Tag] = frozenset(),
         order: str = "C",
         allocator=None) -> TaggableCLArray:
     if dtype is not None:
@@ -181,8 +181,8 @@ def empty(queue, shape, dtype=float, *,
 
 
 def zeros(queue, shape, dtype=float, *,
-        axes: Optional[Tuple[Axis, ...]] = None,
-        tags: FrozenSet[Tag] = frozenset(),
+        axes: tuple[Axis, ...] | None = None,
+        tags: frozenset[Tag] = frozenset(),
         order: str = "C",
         allocator=None) -> TaggableCLArray:
     result = empty(
@@ -194,8 +194,8 @@ def zeros(queue, shape, dtype=float, *,
 
 
 def to_device(queue, ary, *,
-        axes: Optional[Tuple[Axis, ...]] = None,
-        tags: FrozenSet[Tag] = frozenset(),
+        axes: tuple[Axis, ...] | None = None,
+        tags: frozenset[Tag] = frozenset(),
         allocator=None):
     return to_tagged_cl_array(
         cla.to_device(queue, ary, allocator=allocator),
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 099738a..e3c830e 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -47,17 +47,8 @@ THE SOFTWARE.
 
 import abc
 import sys
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    FrozenSet,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-)
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
 
@@ -91,7 +82,7 @@ logger = logging.getLogger(__name__)
 
 # {{{ tag conversion
 
-def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]:
+def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]:
     tags = normalize_tags(tags)
 
     name_hints = [tag for tag in tags if isinstance(tag, NameHint)]
@@ -135,7 +126,7 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
 
     def __init__(
             self, *,
-            compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None
+            compile_trace_callback: Callable[[Any, str, Any], None] | None = None
             ) -> None:
         """
         :arg compile_trace_callback: A function of three arguments
@@ -148,10 +139,10 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
         super().__init__()
 
         import pytato as pt
-        self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
-        self._dag_transform_cache: Dict[
+        self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
+        self._dag_transform_cache: dict[
                 pt.DictOfNamedArrays,
-                Tuple[pt.DictOfNamedArrays, str]] = {}
+                tuple[pt.DictOfNamedArrays, str]] = {}
 
         if compile_trace_callback is None:
             def _compile_trace_callback(what, stage, ir):
@@ -166,7 +157,7 @@ class _BasePytatoArrayContext(ArrayContext, abc.ABC):
         return PytatoFakeNumpyNamespace(self)
 
     @abc.abstractproperty
-    def _frozen_array_types(self) -> Tuple[Type, ...]:
+    def _frozen_array_types(self) -> tuple[type, ...]:
         """
         Returns valid frozen array types for the array context.
         """
@@ -256,11 +247,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
     """
     def __init__(
             self, queue: cl.CommandQueue, allocator=None, *,
-            use_memory_pool: Optional[bool] = None,
-            compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
+            use_memory_pool: bool | None = None,
+            compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
 
             # do not use: only for testing
-            _force_svm_arg_limit: Optional[int] = None,
+            _force_svm_arg_limit: int | None = None,
             ) -> None:
         """
         :arg compile_trace_callback: A function of three arguments
@@ -322,14 +313,14 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
         self._force_svm_arg_limit = _force_svm_arg_limit
 
     @property
-    def _frozen_array_types(self) -> Tuple[Type, ...]:
+    def _frozen_array_types(self) -> tuple[type, ...]:
         import pyopencl.array as cla
         return (cla.Array,)
 
     def _rec_map_container(
             self, func: Callable[[Array], Array], array: ArrayOrContainer,
-            allowed_types: Optional[Tuple[type, ...]] = None, *,
-            default_scalar: Optional[ScalarLike] = None,
+            allowed_types: tuple[type, ...] | None = None, *,
+            default_scalar: ScalarLike | None = None,
             strict: bool = False) -> ArrayOrContainer:
         import pytato as pt
 
@@ -452,13 +443,13 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
             get_cl_axes_from_pt_axes,
         )
 
-        array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {}
-        key_to_frozen_subary: Dict[str, TaggableCLArray] = {}
-        key_to_pt_arrays: Dict[str, pt.Array] = {}
+        array_as_dict: dict[str, cla.Array | TaggableCLArray | pt.Array] = {}
+        key_to_frozen_subary: dict[str, TaggableCLArray] = {}
+        key_to_pt_arrays: dict[str, pt.Array] = {}
 
         def _record_leaf_ary_in_dict(
-                key: Tuple[Any, ...],
-                ary: Union[cla.Array, TaggableCLArray, pt.Array]) -> None:
+                key: tuple[Any, ...],
+                ary: cla.Array | TaggableCLArray | pt.Array) -> None:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             array_as_dict[key_str] = ary
 
@@ -498,7 +489,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
 
         # }}}
 
-        def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
+        def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             return key_to_frozen_subary[key_str]
 
@@ -706,8 +697,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
     """
 
     def __init__(self,
-            *, compile_trace_callback: Optional[Callable[[Any, str, Any], None]]
-             = None) -> None:
+            *,
+            compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
+            ) -> None:
         """
         :arg compile_trace_callback: A function of three arguments
             *(what, stage, ir)*, where *what* identifies the object
@@ -722,14 +714,14 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         self.array_types = (pt.Array, jnp.ndarray)
 
     @property
-    def _frozen_array_types(self) -> Tuple[Type, ...]:
+    def _frozen_array_types(self) -> tuple[type, ...]:
         import jax.numpy as jnp
         return (jnp.ndarray, )
 
     def _rec_map_container(
             self, func: Callable[[Array], Array], array: ArrayOrContainer,
-            allowed_types: Optional[Tuple[type, ...]] = None, *,
-            default_scalar: Optional[ScalarLike] = None,
+            allowed_types: tuple[type, ...] | None = None, *,
+            default_scalar: ScalarLike | None = None,
             strict: bool = False) -> ArrayOrContainer:
         if allowed_types is None:
             allowed_types = self.array_types
@@ -783,12 +775,12 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
         from arraycontext.container.traversal import rec_keyed_map_array_container
         from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
 
-        array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {}
-        key_to_frozen_subary: Dict[str, jnp.ndarray] = {}
-        key_to_pt_arrays: Dict[str, pt.Array] = {}
+        array_as_dict: dict[str, jnp.ndarray | pt.Array] = {}
+        key_to_frozen_subary: dict[str, jnp.ndarray] = {}
+        key_to_pt_arrays: dict[str, pt.Array] = {}
 
-        def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
-                                     ary: Union[jnp.ndarray, pt.Array]) -> None:
+        def _record_leaf_ary_in_dict(key: tuple[Any, ...],
+                                     ary: jnp.ndarray | pt.Array) -> None:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             array_as_dict[key_str] = ary
 
@@ -812,7 +804,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
 
         # }}}
 
-        def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray:
+        def _to_frozen(key: tuple[Any, ...], ary) -> jnp.ndarray:
             key_str = "_ary" + _ary_container_key_stringifier(key)
             return key_to_frozen_subary[key_str]
 
diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py
index 54d2cbb..952761b 100644
--- a/arraycontext/impl/pytato/compile.py
+++ b/arraycontext/impl/pytato/compile.py
@@ -32,8 +32,9 @@ THE SOFTWARE.
 import abc
 import itertools
 import logging
+from collections.abc import Callable, Hashable, Mapping
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type
+from typing import Any
 
 import numpy as np
 from immutabledict import immutabledict
@@ -106,7 +107,7 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
 
 # {{{ utilities
 
-def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
+def _ary_container_key_stringifier(keys: tuple[Any, ...]) -> str:
     """
     Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
     array-container's component's key. Goals of this routine:
@@ -116,7 +117,7 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
     * (informal) Shorter identifiers are preferred
     """
     def _rec_str(key: Any) -> str:
-        if isinstance(key, (str, int)):
+        if isinstance(key, str | int):
             return str(key)
         elif isinstance(key, tuple):
             # t in '_actx_t': stands for tuple
@@ -128,11 +129,11 @@ def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
     return "_".join(_rec_str(key) for key in keys)
 
 
-def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
+def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...],
                                            kwargs: Mapping[str, Any]
                                            ) -> \
-            Tuple[Mapping[Tuple[Any, ...], Any],
-                  Mapping[Tuple[Any, ...], AbstractInputDescriptor]]:
+            tuple[Mapping[tuple[Hashable, ...], Any],
+                  Mapping[tuple[Hashable, ...], AbstractInputDescriptor]]:
     """
     Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts
     mappings from argument id to argument values and from argument id to
@@ -140,8 +141,8 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
     :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's
     representation.
     """
-    arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {}
-    arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {}
+    arg_id_to_arg: dict[tuple[Hashable, ...], Any] = {}
+    arg_id_to_descr: dict[tuple[Hashable, ...], AbstractInputDescriptor] = {}
 
     for kw, arg in itertools.chain(enumerate(args),
                                    kwargs.items()):
@@ -259,7 +260,7 @@ class BaseLazilyCompilingFunctionCaller:
 
     actx: _BasePytatoArrayContext
     f: Callable[..., Any]
-    program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor],
+    program_cache: dict[Mapping[tuple[Hashable, ...], AbstractInputDescriptor],
                         "CompiledFunction"] = field(default_factory=lambda: {})
 
     # {{{ abstract interface
@@ -269,11 +270,11 @@ class BaseLazilyCompilingFunctionCaller:
 
     @property
     def compiled_function_returning_array_container_class(
-            self) -> Type["CompiledFunction"]:
+            self) -> type["CompiledFunction"]:
         raise NotImplementedError
 
     @property
-    def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]:
+    def compiled_function_returning_array_class(self) -> type["CompiledFunction"]:
         raise NotImplementedError
 
     # }}}
@@ -382,11 +383,11 @@ class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
 
     @property
     def compiled_function_returning_array_container_class(
-            self) -> Type["CompiledFunction"]:
+            self) -> type["CompiledFunction"]:
         return CompiledPyOpenCLFunctionReturningArrayContainer
 
     @property
-    def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]:
+    def compiled_function_returning_array_class(self) -> type["CompiledFunction"]:
         return CompiledPyOpenCLFunctionReturningArray
 
     def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
@@ -481,11 +482,11 @@ class LazilyCompilingFunctionCaller(LazilyPyOpenCLCompilingFunctionCaller):
 class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
     @property
     def compiled_function_returning_array_container_class(
-            self) -> Type["CompiledFunction"]:
+            self) -> type["CompiledFunction"]:
         return CompiledJAXFunctionReturningArrayContainer
 
     @property
-    def compiled_function_returning_array_class(self) -> Type["CompiledFunction"]:
+    def compiled_function_returning_array_class(self) -> type["CompiledFunction"]:
         return CompiledJAXFunctionReturningArray
 
     def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
@@ -627,10 +628,10 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
     """
     actx: PytatoPyOpenCLArrayContext
     pytato_program: pt.target.BoundProgram
-    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
-    output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
-    name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
-    name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]]
+    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
+    output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
+    name_in_program_to_tags: Mapping[str, frozenset[Tag]]
+    name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
     output_template: ArrayContainer
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
@@ -670,9 +671,9 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
     """
     actx: PytatoPyOpenCLArrayContext
     pytato_program: pt.target.BoundProgram
-    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
-    output_tags: FrozenSet[Tag]
-    output_axes: Tuple[pt.Axis, ...]
+    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
+    output_tags: frozenset[Tag]
+    output_axes: tuple[pt.Axis, ...]
     output_name: str
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
@@ -719,10 +720,10 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
     """
     actx: PytatoJAXArrayContext
     pytato_program: pt.target.BoundProgram
-    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
-    output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
-    name_in_program_to_tags: Mapping[str, FrozenSet[Tag]]
-    name_in_program_to_axes: Mapping[str, Tuple[pt.Axis, ...]]
+    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
+    output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
+    name_in_program_to_tags: Mapping[str, frozenset[Tag]]
+    name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
     output_template: ArrayContainer
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
@@ -750,9 +751,9 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
     """
     actx: PytatoJAXArrayContext
     pytato_program: pt.target.BoundProgram
-    input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
-    output_tags: FrozenSet[Tag]
-    output_axes: Tuple[pt.Axis, ...]
+    input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
+    output_tags: frozenset[Tag]
+    output_axes: tuple[pt.Axis, ...]
     output_name: str
 
     def __call__(self, arg_id_to_arg) -> ArrayContainer:
diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py
index a5582d1..affa1da 100644
--- a/arraycontext/impl/pytato/utils.py
+++ b/arraycontext/impl/pytato/utils.py
@@ -22,8 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-
-from typing import TYPE_CHECKING, Any, Dict, Mapping, Set, Tuple
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any
 
 from pytato.array import (
     AbstractResultWithNamedArrays,
@@ -54,9 +54,9 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
     """
     def __init__(self) -> None:
         super().__init__()
-        self.bound_arguments: Dict[str, Any] = {}
+        self.bound_arguments: dict[str, Any] = {}
         self.vng = UniqueNameGenerator()
-        self.seen_inputs: Set[str] = set()
+        self.seen_inputs: set[str] = set()
 
     def map_data_wrapper(self, expr: DataWrapper) -> Array:
         if expr.name is not None:
@@ -87,7 +87,7 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
 
 def _normalize_pt_expr(
         expr: DictOfNamedArrays
-        ) -> Tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]:
+        ) -> tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]:
     """
     Returns ``(normalized_expr, bound_arguments)``.  *normalized_expr* is a
     normalized form of *expr*, with all instances of
@@ -102,11 +102,11 @@ def _normalize_pt_expr(
     return normalized_expr, normalize_mapper.bound_arguments
 
 
-def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]:
+def get_pt_axes_from_cl_axes(axes: tuple[ClAxis, ...]) -> tuple[PtAxis, ...]:
     return tuple(PtAxis(axis.tags) for axis in axes)
 
 
-def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]:
+def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]:
     return tuple(ClAxis(axis.tags) for axis in axes)
 
 
diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py
index 1bee3eb..a62023b 100644
--- a/arraycontext/loopy.py
+++ b/arraycontext/loopy.py
@@ -27,7 +27,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import ClassVar, Mapping
+from collections.abc import Mapping
+from typing import ClassVar
 
 import numpy as np
 
diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py
index c778154..f1f62a7 100644
--- a/arraycontext/pytest.py
+++ b/arraycontext/pytest.py
@@ -31,7 +31,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from typing import Any, Callable, Dict, Sequence, Type, Union
+from collections.abc import Callable, Sequence
+from typing import Any
 
 from arraycontext import NumpyArrayContext
 from arraycontext.context import ArrayContext
@@ -244,19 +245,18 @@ class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
 # }}}
 
 
-_ARRAY_CONTEXT_FACTORY_REGISTRY: \
-        Dict[str, Type[PytestArrayContextFactory]] = {
-                "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
-                "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
-                "pytato:jax": _PytestPytatoJaxArrayContextFactory,
-                "eagerjax": _PytestEagerJaxArrayContextFactory,
-                "numpy": _PytestNumpyArrayContextFactory,
-                }
+_ARRAY_CONTEXT_FACTORY_REGISTRY: dict[str, type[PytestArrayContextFactory]] = {
+    "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
+    "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
+    "pytato:jax": _PytestPytatoJaxArrayContextFactory,
+    "eagerjax": _PytestEagerJaxArrayContextFactory,
+    "numpy": _PytestNumpyArrayContextFactory,
+    }
 
 
 def register_pytest_array_context_factory(
         name: str,
-        factory: Type[PytestArrayContextFactory]) -> None:
+        factory: type[PytestArrayContextFactory]) -> None:
     if name in _ARRAY_CONTEXT_FACTORY_REGISTRY:
         raise ValueError(f"factory '{name}' already exists")
 
@@ -268,7 +268,7 @@ def register_pytest_array_context_factory(
 # {{{ pytest integration
 
 def pytest_generate_tests_for_array_contexts(
-        factories: Sequence[Union[str, Type[PytestArrayContextFactory]]], *,
+        factories: Sequence[str | type[PytestArrayContextFactory]], *,
         factory_arg_name: str = "actx_factory",
         ) -> Callable[[Any], None]:
     """Parametrize tests for pytest to use an :class:`~arraycontext.ArrayContext`.
diff --git a/arraycontext/version.py b/arraycontext/version.py
index 31baea0..05fe876 100644
--- a/arraycontext/version.py
+++ b/arraycontext/version.py
@@ -1,8 +1,7 @@
 from importlib import metadata
-from typing import Tuple
 
 
-def _parse_version(version: str) -> Tuple[Tuple[int, ...], str]:
+def _parse_version(version: str) -> tuple[tuple[int, ...], str]:
     import re
 
     m = re.match("^([0-9.]+)([a-z0-9]*?)$", VERSION_TEXT)
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 7bea0dc..5cffb20 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -255,7 +255,7 @@ class MyContainerDOFBcast:
 
 def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
     from numbers import Number
-    if isinstance(shapes, (Number, tuple)):
+    if isinstance(shapes, Number | tuple):
         shapes = [shapes]
 
     x = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64))
@@ -1072,7 +1072,7 @@ def test_flatten_array_container(actx_factory, shapes):
 
     # {{{ complex to real
 
-    if isinstance(shapes, (int, tuple)):
+    if isinstance(shapes, int | tuple):
         shapes = [shapes]
 
     ary = DOFArray(actx, tuple(actx.from_numpy(randn(shape, np.float64))
@@ -1556,7 +1556,7 @@ def test_to_numpy_on_frozen_arrays(actx_factory):
 def test_tagging(actx_factory):
     actx = actx_factory()
 
-    if isinstance(actx, (NumpyArrayContext, EagerJAXArrayContext)):
+    if isinstance(actx, NumpyArrayContext | EagerJAXArrayContext):
         pytest.skip(f"{type(actx)} has no tagging support")
 
     from pytools.tag import Tag
-- 
GitLab