diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index b29dc86df66461ab9ca214e347954f3742baae31..2469334de47c47718d3c6d8ce1063de88362232c 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -70,7 +70,7 @@ import numpy as np
 
 from arraycontext.context import ArrayContext, DeviceArray
 from arraycontext.container import (
-        ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
+        ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
         serialize_container, deserialize_container)
 
 
@@ -327,8 +327,11 @@ def multimapped_over_array_containers(
 
 # {{{ keyed array container traversal
 
-def keyed_map_array_container(f: Callable[[Any, Any], Any],
-                              ary: ArrayOrContainerT) -> ArrayOrContainerT:
+def keyed_map_array_container(
+        f: Callable[
+            [Any, ArrayOrContainerT],
+            ArrayOrContainerT],
+        ary: ArrayOrContainerT) -> ArrayOrContainerT:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`map_array_container`, but *f* also takes an
@@ -350,8 +353,9 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any],
             ])
 
 
-def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
-                                  ary: ArrayOrContainerT) -> ArrayOrContainerT:
+def rec_keyed_map_array_container(
+        f: Callable[[Tuple[Any, ...], ArrayT], ArrayT],
+        ary: ArrayOrContainerT) -> ArrayOrContainerT:
     """
     Works similarly to :func:`rec_map_array_container`, except that *f* also
     takes in a traversal path to the leaf array. The traversal path argument is