diff --git a/meshmode/dof_array.py b/meshmode/dof_array.py
index 3136484c3290e02b9f87fdabb698d3e006555f6a..12ffe33fb19f83e3a38a10f52d434166ccd93e79 100644
--- a/meshmode/dof_array.py
+++ b/meshmode/dof_array.py
@@ -21,7 +21,7 @@ THE SOFTWARE.
 """
 
 import numpy as np
-from typing import Optional, TYPE_CHECKING
+from typing import Optional, Iterable, TYPE_CHECKING
 from functools import partial
 
 from pytools import single_valued, memoize_in
@@ -53,11 +53,15 @@ class DOFArray(np.ndarray):
     :class:`~meshmode.discretization.ElementGroupBase`.
     The arrays contained within a :class:`DOFArray`
     are expected to be logically two-dimensional, with shape
-    ``(nelements, nunit_dofs)``, using
-    :class:`~meshmode.discretization.ElementGroupBase.nelements`
-    and :class:`~meshmode.discretization.ElementGroupBase.nunit_dofs`.
-    It is derived from :class:`numpy.ndarray` with dtype object ("an object array").
-    The entries in this array are further arrays managed by :attr:`array_context`.
+    ``(nelements, ndofs_per_element)``, where ``nelements`` is the same as
+    :attr:`~meshmode.discretization.ElementGroupBase.nelements`
+    of the associated group.
+    ``ndofs_per_element`` is typically, but not necessarily, the same as
+    :attr:`~meshmode.discretization.ElementGroupBase.nunit_dofs`
+    of the associated group.
+    This array is derived from :class:`numpy.ndarray` with dtype object ("an
+    object array").  The entries in this array are further arrays managed by
+    :attr:`array_context`.
 
     One main purpose of this class is to describe the data structure,
     i.e. when a :class:`DOFArray` occurs inside of further numpy object array,
@@ -105,7 +109,7 @@ class DOFArray(np.ndarray):
         (one per :class:`~meshmode.discretization.ElementGroupBase`).
 
         :arg actx: If *None*, the arrays in *res_list* must be
-        :meth:`~meshmode.array_context.ArrayContext.thaw`\ ed.
+            :meth:`~meshmode.array_context.ArrayContext.thaw`\ ed.
         """
         if not (actx is None or isinstance(actx, ArrayContext)):
             raise TypeError("actx must be of type ArrayContext")
@@ -190,8 +194,9 @@ def flatten(ary: np.ndarray) -> np.ndarray:
     @memoize_in(actx, (flatten, "flatten_prg"))
     def prg():
         return make_loopy_program(
-            "{[iel,idof]: 0<=iel<nelements and 0<=idof<nunit_dofs}",
-            "result[grp_start + iel*nunit_dofs + idof] = grp_ary[iel, idof]",
+            "{[iel,idof]: 0<=iel<nelements and 0<=idof<ndofs_per_element}",
+            """result[grp_start + iel*ndofs_per_element + idof] \
+                = grp_ary[iel, idof]""",
             name="flatten")
 
     result = actx.empty(group_starts[-1], dtype=ary.entry_dtype)
@@ -202,7 +207,8 @@ def flatten(ary: np.ndarray) -> np.ndarray:
     return result
 
 
-def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray:
+def unflatten(actx: ArrayContext, discr: "_Discretization", ary,
+        ndofs_per_element_per_group: Optional[Iterable[int]] = None) -> np.ndarray:
     r"""Convert a 'flat' array returned by :func:`flatten` back to a :class:`DOFArray`.
 
     Vectorizes over object arrays of :class:`DOFArray`\ s.
@@ -211,17 +217,26 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray:
             and ary.dtype.char == "O"
             and not isinstance(ary, DOFArray)):
         return obj_array_vectorize(
-                lambda subary: unflatten(actx, discr, subary),
+                lambda subary: unflatten(
+                    actx, discr, subary, ndofs_per_element_per_group),
                 ary)
 
     @memoize_in(actx, (unflatten, "unflatten_prg"))
     def prg():
         return make_loopy_program(
-            "{[iel,idof]: 0<=iel<nelements and 0<=idof<nunit_dofs}",
-            "result[iel, idof] = ary[grp_start + iel*nunit_dofs + idof]",
+            "{[iel,idof]: 0<=iel<nelements and 0<=idof<ndofs_per_element}",
+            "result[iel, idof] = ary[grp_start + iel*ndofs_per_element + idof]",
             name="unflatten")
 
-    group_sizes = [grp.ndofs for grp in discr.groups]
+    if ndofs_per_element_per_group is None:
+        ndofs_per_element_per_group = [
+                grp.nunit_dofs for grp in discr.groups]
+
+    group_sizes = [
+            grp.nelements * ndofs_per_element
+            for grp, ndofs_per_element
+            in zip(discr.groups, ndofs_per_element_per_group)]
+
     if ary.size != sum(group_sizes):
         raise ValueError("array has size %d, expected %d"
                 % (ary.size, sum(group_sizes)))
@@ -234,9 +249,12 @@ def unflatten(actx: ArrayContext, discr: "_Discretization", ary) -> np.ndarray:
                     prg(),
                     grp_start=grp_start, ary=ary,
                     nelements=grp.nelements,
-                    nunit_dofs=grp.nunit_dofs,
+                    ndofs_per_element=ndofs_per_element,
                     )["result"])
-            for grp_start, grp in zip(group_starts, discr.groups)])
+            for grp_start, grp, ndofs_per_element in zip(
+                    group_starts,
+                    discr.groups,
+                    ndofs_per_element_per_group)])
 
 
 def flat_norm(ary: DOFArray, ord=2):