diff --git a/pytools/__init__.py b/pytools/__init__.py
index e018f949836cef6420c50ecf75d24e6aed5fe864..5f1203e16ca54bab39519c6c85cd26e1b4b257a0 100644
--- a/pytools/__init__.py
+++ b/pytools/__init__.py
@@ -41,9 +41,15 @@ import math
 from sys import intern
 
 try:
-    from typing import SupportsIndex, ParamSpec, Concatenate
+    from typing import SupportsIndex, Concatenate
 except ImportError:
-    from typing_extensions import SupportsIndex, ParamSpec, Concatenate
+    from typing_extensions import SupportsIndex, Concatenate
+
+try:
+    from typing import ParamSpec
+except ImportError:
+    from typing_extensions import ParamSpec  # type: ignore[assignment]
+
 
 # These are deprecated and will go away in 2022.
 all = builtins.all
diff --git a/pytools/graph.py b/pytools/graph.py
index 09c1b4121f1575d07e4371567d9a0ce1daccc973..1eb9147d40dd697d166e53b0ddac5898705a0719 100644
--- a/pytools/graph.py
+++ b/pytools/graph.py
@@ -37,6 +37,7 @@ Graph Algorithms
 .. autofunction:: compute_transitive_closure
 .. autofunction:: contains_cycle
 .. autofunction:: compute_induced_subgraph
+.. autofunction:: validate_graph
 
 Type Variables Used
 -------------------
@@ -394,4 +395,28 @@ def compute_induced_subgraph(graph: Mapping[T, Set[T]],
 
 # }}}
 
+
+# {{{ validate graph
+
+def validate_graph(graph: Mapping[T, Collection[T]]) -> None:
+    """
+    Validates that all successor nodes of each node in *graph* are keys in
+    *graph* itself. Raises a :class:`ValueError` if not.
+
+    :arg graph: A :class:`collections.abc.Mapping` representing a directed
+        graph. The dictionary contains one key representing each node in the
+        graph, and this key maps to a :class:`collections.abc.Collection` of nodes
+        that are connected to the node by outgoing edges.
+    """
+    seen_nodes: Set[T] = set()
+
+    for children in graph.values():
+        seen_nodes.update(children)
+
+    if not seen_nodes <= graph.keys():
+        raise ValueError(
+            f"invalid graph, missing keys: {seen_nodes-graph.keys()}")
+
+# }}}
+
 # vim: foldmethod=marker
diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py
index 5c7a06e844a6c4bd96e23606f1c837c101067625..4bb9b8ff9b1205fe846adc6821a1525ff027e27e 100644
--- a/test/test_graph_tools.py
+++ b/test/test_graph_tools.py
@@ -312,6 +312,42 @@ def test_reverse_graph():
     assert graph == reverse_graph(reverse_graph(graph))
 
 
+def test_validate_graph():
+    from pytools.graph import validate_graph
+    graph1 = {
+            "d": set("c"),
+            "b": set("a"),
+            "a": set(),
+            "c": set("a"),
+            }
+
+    validate_graph(graph1)
+
+    graph2 = {
+            "d": set("d"),
+            "b": set("c"),
+            "a": set("b"),
+            "c": set("a"),
+            }
+
+    validate_graph(graph2)
+
+    graph3 = {
+        "a": {"b", "c"},
+        "b": {"d", "e"},
+        "c": {"d", "f"},
+        "d": set(),
+        "e": set(),
+        "f": {"g"},
+        "g": {"h", "i", "j"},  # h, i, j missing from keys
+        }
+
+    with pytest.raises(ValueError):
+        validate_graph(graph3)
+
+    validate_graph({})
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])