diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index ed6dc0c2482290a951258862aa392096a3040c16..76f7647179da16e41e321a1b1687864b15abbf01 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -142,13 +142,13 @@ def _multimap_array_container_impl(
     if len(container_indices) == 1 and reduce_func is None:
         # NOTE: if we just have one ArrayContainer in args, passing it through
         # _map_array_container_impl should be faster
-        def wrapper(ary: ContainerT) -> ContainerT:
+        def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT:
             new_args = list(args)
             new_args[container_indices[0]] = ary
             return f(*new_args)
 
         update_wrapper(wrapper, f)
-        template_ary: ContainerT = args[container_indices[0]]
+        template_ary: ArrayOrContainerT = args[container_indices[0]]
         return _map_array_container_impl(
                 wrapper, template_ary,
                 leaf_cls=leaf_cls, recursive=recursive)