diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index c9b1282d898ae7b3466c19c0d7108d6c4a40bb1e..7f08c1e479e4b9a3d3e9530393d6450903973bcb 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -631,7 +631,10 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
 
     _flatten(ary)
 
-    return actx.np.concatenate(result)
+    if len(result) == 1:
+        return result[0]
+    else:
+        return actx.np.concatenate(result)
 
 
 def unflatten(