diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py
index 150d1d6838aae19c369126be75e70216e6f228a3..edbb45061c230d7d34d0e88d95896dd40ef8d6f4 100644
--- a/arraycontext/container/dataclass.py
+++ b/arraycontext/container/dataclass.py
@@ -30,6 +30,13 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from typing import Union, get_args
+try:
+    # NOTE: only available in python >= 3.8
+    from typing import get_origin
+except ImportError:
+    from typing_extensions import get_origin
+
 from dataclasses import fields
 from arraycontext.container import is_array_container_type
 
@@ -50,6 +57,19 @@ def dataclass_array_container(cls: type) -> type:
     assert is_dataclass(cls)
 
     def is_array_field(f: Field) -> bool:
+        from arraycontext import Array
+
+        origin = get_origin(f.type)
+        if origin is Union:
+            if not all(
+                    arg is Array or is_array_container_type(arg)
+                    for arg in get_args(f.type)):
+                raise TypeError(
+                        f"Field '{f.name}' union contains non-array container "
+                        "arguments. All arguments must be array containers.")
+            else:
+                return True
+
         if __debug__:
             if not f.init:
                 raise ValueError(
@@ -61,6 +81,7 @@ def dataclass_array_container(cls: type) -> type:
 
             from typing import _SpecialForm
             if isinstance(f.type, _SpecialForm):
+                # NOTE: anything except a Union is not allowed
                 raise TypeError(
                         f"typing annotation not supported on field '{f.name}': "
                         f"'{f.type!r}'")
@@ -70,7 +91,6 @@ def dataclass_array_container(cls: type) -> type:
                         f"field '{f.name}' not an instance of 'type': "
                         f"'{f.type!r}'")
 
-        from arraycontext import Array
         return f.type is Array or is_array_container_type(f.type)
 
     from pytools import partition
diff --git a/setup.py b/setup.py
index 2bc066ec81fb564ed1093f98231bae87973a3243..eb6421c29d28d2d00ce7e5e259285d178a1f12c3 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@ def main():
             "pytest>=2.3",
             "loopy>=2019.1",
             "dataclasses; python_version<'3.7'",
-            "typing_extensions; python_version<'3.8'",
+            "typing_extensions; python_version<'3.9'",
             "types-dataclasses",
         ],
         package_data={"arraycontext": ["py.typed"]},
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index cd61120f0c7a8ca04347d3595ade5d8bf0dd5617..acf099716829400fd079fdf1dc0913c138a8dcd5 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -21,6 +21,8 @@ THE SOFTWARE.
 """
 
 from dataclasses import dataclass
+from typing import Union
+
 import numpy as np
 import pytest
 
@@ -678,9 +680,9 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec):
 @dataclass(frozen=True)
 class MyContainer:
     name: str
-    mass: DOFArray   # or np.ndarray
+    mass: Union[DOFArray, np.ndarray]
     momentum: np.ndarray
-    enthalpy: DOFArray   # or np.ndarray
+    enthalpy: Union[DOFArray, np.ndarray]
 
     @property
     def array_context(self):
@@ -700,9 +702,9 @@ class MyContainer:
 @dataclass(frozen=True)
 class MyContainerDOFBcast:
     name: str
-    mass: DOFArray  # or np.ndarray
+    mass: Union[DOFArray, np.ndarray]
     momentum: np.ndarray
-    enthalpy: DOFArray  # or np.ndarray
+    enthalpy: Union[DOFArray, np.ndarray]
 
     @property
     def array_context(self):
diff --git a/test/test_utils.py b/test/test_utils.py
index ac3127fe8ce7843f5a7c513eb30baa6688e6e73a..7a12ad273c87469809d2d1c6d8626dfd434b8c77 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -106,6 +106,42 @@ def test_dataclass_array_container():
 # }}}
 
 
+# {{{ test_dataclass_container_unions
+
+def test_dataclass_container_unions():
+    from dataclasses import dataclass
+    from arraycontext import dataclass_array_container
+
+    from typing import Union
+    from arraycontext import Array
+
+    # {{{ union fields
+
+    @dataclass
+    class ArrayContainerWithUnion:
+        x: np.ndarray
+        y: Union[np.ndarray, Array]
+
+    dataclass_array_container(ArrayContainerWithUnion)
+
+    # }}}
+
+    # {{{ non-container union
+
+    @dataclass
+    class ArrayContainerWithWrongUnion:
+        x: np.ndarray
+        y: Union[np.ndarray, float]
+
+    with pytest.raises(TypeError):
+        # NOTE: float is not an ArrayContainer, so y should fail
+        dataclass_array_container(ArrayContainerWithWrongUnion)
+
+    # }}}
+
+# }}}
+
+
 if __name__ == "__main__":
     import sys
     if len(sys.argv) > 1: