diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index c339609c98cde40599bb90555339f3937cc45a65..3eb3664f4c9be355424d2fb4ef0963d9d4b91337 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -603,29 +603,51 @@ def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
 
 def unflatten(
         template: ArrayOrContainerT, ary: Any,
-        actx: ArrayContext) -> ArrayOrContainerT:
+        actx: ArrayContext, *,
+        strict: bool = True) -> ArrayOrContainerT:
     """Unflatten an array *ary* produced by :func:`flatten` back into an
     :class:`~arraycontext.ArrayContainer`.
 
     The order and sizes of each slice into *ary* are determined by the
     array container *template*.
+
+    :arg strict: if *True* additional :class:`~numpy.dtype` and stride
+        checking is performed on the unflattened array. Otherwise, these
+        checks are skipped.
     """
     # NOTE: https://github.com/python/mypy/issues/7057
     offset = 0
+    common_dtype = None
 
     def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
-        nonlocal offset
+        nonlocal offset, common_dtype
 
         try:
             iterable = serialize_container(template_subary)
         except NotAnArrayContainerError:
+            # {{{ validate subary
+
             if (offset + template_subary.size) > ary.size:
                 raise ValueError("'template' and 'ary' sizes do not match: "
                     "'template' is too large")
 
-            if template_subary.dtype != ary.dtype:
-                raise ValueError("'template' dtype does not match 'ary': "
-                        f"got {template_subary.dtype}, expected {ary.dtype}")
+            if strict:
+                if template_subary.dtype != ary.dtype:
+                    raise ValueError("'template' dtype does not match 'ary': "
+                            f"got {template_subary.dtype}, expected {ary.dtype}")
+            else:
+                # NOTE: still require that *template* has a uniform dtype
+                if common_dtype is None:
+                    common_dtype = template_subary.dtype
+                else:
+                    if common_dtype != template_subary.dtype:
+                        raise ValueError("arrays in 'template' have different "
+                                f"dtypes: got {template_subary.dtype}, but "
+                                f"expected {common_dtype}.")
+
+            # }}}
+
+            # {{{ reshape
 
             flat_subary = ary[offset:offset + template_subary.size]
             try:
@@ -640,12 +662,18 @@ def unflatten(
                         "This functionality needs to be implemented by the "
                         "array context.") from exc
 
-            if hasattr(template_subary, "strides"):
+            # }}}
+
+            # {{{ check strides
+
+            if strict and hasattr(template_subary, "strides"):
                 if template_subary.strides != subary.strides:
                     raise ValueError(
                             f"strides do not match template: got {subary.strides}, "
                             f"expected {template_subary.strides}")
 
+            # }}}
+
             offset += template_subary.size
             return subary
         else:
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 632c95bd27c692bb6ac6c56f67f1d6d021b6b5aa..d97f00f0f1dd1a72c997109e28914376c917652d 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -926,6 +926,28 @@ def test_flatten_array_container(actx_factory, shapes):
                 actx.np.linalg.norm(ary - ary_roundtrip)
                 ) < 1.0e-15
 
+    # {{{ complex to real
+
+    if isinstance(shapes, (int, tuple)):
+        shapes = [shapes]
+
+    ary = DOFArray(actx, tuple([
+        actx.from_numpy(randn(shape, np.float64))
+        for shape in shapes]))
+
+    template = DOFArray(actx, tuple([
+        actx.from_numpy(randn(shape, np.complex128))
+        for shape in shapes]))
+
+    flat = flatten(ary, actx)
+    ary_roundtrip = unflatten(template, flat, actx, strict=False)
+
+    assert actx.to_numpy(
+            actx.np.linalg.norm(ary - ary_roundtrip)
+            ) < 1.0e-15
+
+    # }}}
+
 
 def test_flatten_array_container_failure(actx_factory):
     actx = actx_factory()