diff --git a/pytools/graph.py b/pytools/graph.py
index 91f2bffa7a2096d10a46d76b7dbacd22f0b57a4c..baed09958d0c8dc481b6ea447cda8334f0b33716 100644
--- a/pytools/graph.py
+++ b/pytools/graph.py
@@ -5,6 +5,7 @@ __copyright__ = """
 Copyright (C) 2009-2013 Andreas Kloeckner
 Copyright (C) 2020 Matt Wala
 Copyright (C) 2020 James Stevens
+Copyright (C) 2024 Addison Alvey-Blanco
 """
 
 __license__ = """
@@ -47,6 +48,8 @@ Graph Algorithms
 .. autofunction:: as_graphviz_dot
 .. autofunction:: validate_graph
 .. autofunction:: is_connected
+.. autofunction:: undirected_graph_from_edges
+.. autofunction:: get_reachable_nodes
 
 Type Variables Used
 -------------------
@@ -71,13 +74,16 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    FrozenSet,
     Generic,
     Hashable,
+    Iterable,
     Iterator,
     List,
     Mapping,
     MutableSet,
     Optional,
+    Protocol,
     Set,
     Tuple,
     TypeVar,
@@ -98,7 +104,6 @@ else:
 
 NodeT = TypeVar("NodeT", bound=Hashable)
 
-
 GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]]
 
 
@@ -263,8 +268,13 @@ class CycleError(Exception):
         self.node = node
 
 
+class _SupportsLT(Protocol):
+    def __lt__(self, other: object) -> bool:
+        ...
+
+
 @dataclass(frozen=True)
-class HeapEntry(Generic[NodeT]):
+class _HeapEntry(Generic[NodeT]):
     """
     Helper class to compare associated keys while comparing the elements in
     heap operations.
@@ -273,9 +283,9 @@ class HeapEntry(Generic[NodeT]):
     <https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
     """
     node: NodeT
-    key: Any
+    key: _SupportsLT
 
-    def __lt__(self, other: HeapEntry) -> bool:
+    def __lt__(self, other: _HeapEntry[NodeT]) -> bool:
         return self.key < other.key
 
 
@@ -321,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT],
     # heap: list of instances of HeapEntry(n) where 'n' is a node in
     # 'graph' with no predecessor. Nodes with no predecessors are the
     # schedulable candidates.
-    heap = [HeapEntry(n, keyfunc(n))
+    heap = [_HeapEntry(n, keyfunc(n))
             for n, num_preds in nodes_to_num_predecessors.items()
             if num_preds == 0]
     heapify(heap)
@@ -336,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT],
         for child in graph.get(node_to_be_scheduled, ()):
             nodes_to_num_predecessors[child] -= 1
             if nodes_to_num_predecessors[child] == 0:
-                heappush(heap, HeapEntry(child, keyfunc(child)))
+                heappush(heap, _HeapEntry(child, keyfunc(child)))
 
     if len(order) != total_num_nodes:
         # any node which has a predecessor left is a part of a cycle
@@ -457,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT],
     from pytools.graphviz import dot_escape
 
     if node_labels is None:
-        def node_labels(x):
+        def node_labels(x: NodeT) -> str:
             return str(x)
 
     if edge_labels is None:
-        def edge_labels(x, y):
+        def edge_labels(x: NodeT, y: NodeT) -> str:
             return ""
 
     node_to_id = {}
@@ -511,7 +521,7 @@ def validate_graph(graph: GraphT[NodeT]) -> None:
 # }}}
 
 
-# {{{
+# {{{ is_connected
 
 def is_connected(graph: GraphT[NodeT]) -> bool:
     """
@@ -542,5 +552,52 @@ def is_connected(graph: GraphT[NodeT]) -> bool:
 
     return visited == graph.keys()
 
+# }}}
+
+
+def undirected_graph_from_edges(
+            edges: Iterable[Tuple[NodeT, NodeT]],
+        ) -> GraphT[NodeT]:
+    """
+    Constructs an undirected graph using *edges*.
+
+    :arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s.
+
+    :returns: A :class:`GraphT` that is the undirected graph.
+    """
+    undirected_graph: Dict[NodeT, Set[NodeT]] = {}
+
+    for lhs, rhs in edges:
+        if lhs == rhs:
+            raise TypeError("Found loop in edges,"
+                            f" LHS, RHS = {lhs}")
+
+        undirected_graph.setdefault(lhs, set()).add(rhs)
+        undirected_graph.setdefault(rhs, set()).add(lhs)
+
+    return undirected_graph
+
+
+def get_reachable_nodes(
+        undirected_graph: GraphT[NodeT],
+        source_node: NodeT) -> FrozenSet[NodeT]:
+    """
+    Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
+    reachable from *source_node*.
+    """
+    nodes_visited: Set[NodeT] = set()
+    nodes_to_visit = {source_node}
+
+    while nodes_to_visit:
+        current_node = nodes_to_visit.pop()
+        nodes_visited.add(current_node)
+
+        neighbors = undirected_graph[current_node]
+        nodes_to_visit.update({node
+                               for node in neighbors
+                               if node not in nodes_visited})
+
+    return frozenset(nodes_visited)
+
 
 # vim: foldmethod=marker
diff --git a/pytools/test/test_graph_tools.py b/pytools/test/test_graph_tools.py
index a98986edb45771de16b7c882693b212ef07d4061..57c462959964772338694ab2ff8f4ab6fa3e4f8b 100644
--- a/pytools/test/test_graph_tools.py
+++ b/pytools/test/test_graph_tools.py
@@ -431,6 +431,38 @@ def test_is_connected():
     assert is_connected({})
 
 
+def test_propagation_graph_tools():
+    from pytools.graph import (
+        get_reachable_nodes,
+        undirected_graph_from_edges,
+    )
+
+    vars = {"a", "b", "c", "d", "e", "f", "g"}
+
+    constraints = [
+        ("a", "b"),
+        ("a", "d"),
+        ("c", "d"),
+        ("e", "f"),
+        ("f", "g")
+    ]
+
+    all_reachable_nodes = {
+        "a": frozenset({"a", "b", "c", "d"}),
+        "b": frozenset({"a", "b", "c", "d"}),
+        "c": frozenset({"a", "b", "c", "d"}),
+        "e": frozenset({"e", "f", "g"}),
+        "f": frozenset({"e", "f", "g"}),
+        "g": frozenset({"e", "f", "g"})
+    }
+
+    propagation_graph = undirected_graph_from_edges(constraints)
+    assert (
+        all_reachable_nodes[var] == get_reachable_nodes(propagation_graph, var)
+        for var in vars
+    )
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
diff --git a/run-mypy.sh b/run-mypy.sh
index 244d6cc47f454df60e888a80e216a4dd9bf594dc..0220058bb26f5193729a5ac76b7c503754e846d5 100755
--- a/run-mypy.sh
+++ b/run-mypy.sh
@@ -6,5 +6,6 @@ mypy --show-error-codes pytools
 
 mypy --strict --follow-imports=silent \
     pytools/tag.py \
+    pytools/graph.py \
     pytools/datatable.py \
     pytools/persistent_dict.py