diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index c9b1282d898ae7b3466c19c0d7108d6c4a40bb1e..7f08c1e479e4b9a3d3e9530393d6450903973bcb 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -631,7 +631,10 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: _flatten(ary) - return actx.np.concatenate(result) + if len(result) == 1: + return result[0] + else: + return actx.np.concatenate(result) def unflatten(