diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 7f08c1e479e4b9a3d3e9530393d6450903973bcb..145ef63503ff105023534bced05aca763f96c23b 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -583,7 +583,10 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
 
 # {{{ flatten / unflatten
 
-def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
+def flatten(
+        ary: ArrayOrContainerT, actx: ArrayContext, *,
+        leaf_class: Optional[type] = None,
+        ) -> Any:
     """Convert all arrays in the :class:`~arraycontext.ArrayContainer`
     into single flat array of a type :attr:`arraycontext.ArrayContext.array_types`.
 
@@ -591,11 +594,20 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
     ``ravel`` and ``concatenate`` methods implemented. The order in which the
     individual leaf arrays appear in the final array is dependent on the order
     given by :func:`~arraycontext.serialize_container`.
+
+    If *leaf_class* is given, then :func:`unflatten` will not be able to recover
+    the original *ary*.
+
+    :arg leaf_class: an :class:`~arraycontext.ArrayContainer` class on which
+        the recursion is stopped (subclasses are not considered). If given, only
+        the entries of this type are flattened and the rest of the tree
+        structure is left as is. By default, the recursion is stopped when
+        a non-:class:`~arraycontext.ArrayContainer` is found, which results in
+        the whole input container *ary* being flattened.
     """
     common_dtype = None
-    result: List[Any] = []
 
-    def _flatten(subary: ArrayOrContainerT) -> None:
+    def _flatten(subary: ArrayOrContainerT) -> List[Any]:
         nonlocal common_dtype
 
         try:
@@ -624,17 +636,40 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
                         "This functionality needs to be implemented by the "
                         "array context.") from exc
 
-            result.append(flat_subary)
+            result = [flat_subary]
         else:
+            result = []
             for _, isubary in iterable:
-                _flatten(isubary)
+                result.extend(_flatten(isubary))
+
+        return result
+
+    def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any:
+        result = _flatten(subary)
+
+        if len(result) == 1:
+            return result[0]
+        else:
+            return actx.np.concatenate(result)
+
+    def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any:
+        if type(subary) is leaf_class:
+            return _flatten_without_leaf_class(subary)
 
-    _flatten(ary)
+        try:
+            iterable = serialize_container(subary)
+        except NotAnArrayContainerError:
+            return subary
+        else:
+            return deserialize_container(subary, [
+                (key, _flatten_with_leaf_class(isubary))
+                for key, isubary in iterable
+                ])
 
-    if len(result) == 1:
-        return result[0]
+    if leaf_class is None:
+        return _flatten_without_leaf_class(ary)
     else:
-        return actx.np.concatenate(result)
+        return _flatten_with_leaf_class(ary)
 
 
 def unflatten(
@@ -647,6 +682,8 @@ def unflatten(
     The order and sizes of each slice into *ary* are determined by the
     array container *template*.
 
+    :arg ary: a flat one-dimensional array with a size that matches the
+        number of entries in *template*.
     :arg strict: if *True* additional :class:`~numpy.dtype` and stride
         checking is performed on the unflattened array. Otherwise, these
         checks are skipped.
diff --git a/run-pylint.sh b/run-pylint.sh
index 6c7386669da645c3349bfb09355b80d0573f7273..c103b82edb447389387d16db7aec22a14524f5e9 100755
--- a/run-pylint.sh
+++ b/run-pylint.sh
@@ -20,4 +20,4 @@ if [[ -f .pylintrc-local.yml ]]; then
     PYLINT_RUNNER_ARGS+=" --yaml-rcfile=.pylintrc-local.yml"
 fi
 
-python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@"
+PYTHONWARNINGS=ignore python .run-pylint.py $PYLINT_RUNNER_ARGS $(basename $PWD) test/test_*.py examples "$@"
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index ac7ebbf5798c232cc5ce8aa8bb4b382b98112d46..97f853e036f81b724b3e2e4130cf65acaf6a08b0 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -148,6 +148,10 @@ class DOFArray:
         # Why tuple([...])? https://stackoverflow.com/a/48592299
         return (f"{template_instance_name}.array_context, tuple([{arg}])")
 
+    @property
+    def size(self):
+        return sum(ary.size for ary in self.data)
+
     @property
     def real(self):
         return DOFArray(self.array_context, tuple([subary.real for subary in self]))
@@ -1064,6 +1068,30 @@ def test_flatten_array_container_failure(actx_factory):
         # cannot unflatten partially
         unflatten(ary, flat_ary[:-1], actx)
 
+
+def test_flatten_with_leaf_class(actx_factory):
+    actx = actx_factory()
+
+    from arraycontext import flatten
+    arys = _get_test_containers(actx, shapes=512)
+
+    flat = flatten(arys[0], actx, leaf_class=DOFArray)
+    assert isinstance(flat, actx.array_types)
+    assert flat.shape == (arys[0].size,)
+
+    flat = flatten(arys[1], actx, leaf_class=DOFArray)
+    assert isinstance(flat, np.ndarray) and flat.dtype == object
+    assert all(isinstance(entry, actx.array_types) for entry in flat)
+    assert all(entry.shape == (arys[0].size,) for entry in flat)
+
+    flat = flatten(arys[3], actx, leaf_class=DOFArray)
+    assert isinstance(flat, MyContainer)
+    assert isinstance(flat.mass, actx.array_types)
+    assert flat.mass.shape == (arys[3].mass.size,)
+    assert isinstance(flat.enthalpy, actx.array_types)
+    assert flat.enthalpy.shape == (arys[3].enthalpy.size,)
+    assert all(isinstance(entry, actx.array_types) for entry in flat.momentum)
+
 # }}}