diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 2ceb9f92d283c0de95a305faa2443731a53a17e0..03587b51c5f2ce11809586b27b47cb7d649186ea 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -101,14 +101,14 @@ def _multimap_array_container_impl(
     """
     def rec(*_args: Any) -> Any:
         template_ary = _args[container_indices[0]]
-        assert all(
-                type(_args[i]) is type(template_ary) for i in container_indices[1:]
-                ), f"expected type '{type(template_ary).__name__}'"
-
         if (type(template_ary) is leaf_cls
                 or not is_array_container(template_ary)):
             return f(*_args)
 
+        assert all(
+                type(_args[i]) is type(template_ary) for i in container_indices[1:]
+                ), f"expected type '{type(template_ary).__name__}'"
+
         result = []
         new_args = list(_args)