diff --git a/pytools/graph.py b/pytools/graph.py
index c9ce84aaf511e819b062ed20fec259b1182daab6..095d76a43d5f923e3215662f716a801798283626 100644
--- a/pytools/graph.py
+++ b/pytools/graph.py
@@ -38,8 +38,21 @@ Graph Algorithms
 .. autofunction:: compute_transitive_closure
 .. autofunction:: contains_cycle
 .. autofunction:: compute_induced_subgraph
+
+Type Variables Used
+-------------------
+
+.. class:: T
+
+    Any type.
 """
 
+from typing import (TypeVar, Mapping, Iterable, List, Optional, Any, Callable,
+                    Set, MutableSet, Dict, Iterator, Tuple)
+
+
+T = TypeVar("T")
+
 
 # {{{ a_star
 
@@ -109,19 +122,21 @@ def a_star(  # pylint: disable=too-many-locals
 
 # {{{ compute SCCs with Tarjan's algorithm
 
-def compute_sccs(graph):
+def compute_sccs(graph: Mapping[T, Iterable[T]]) -> List[List[T]]:
     to_search = set(graph.keys())
-    visit_order = {}
+    visit_order: Dict[T, int] = {}
     scc_root = {}
     sccs = []
 
     while to_search:
         top = next(iter(to_search))
-        call_stack = [(top, iter(graph[top]), None)]
+        call_stack: List[Tuple[T, Iterator[T], Optional[T]]] = [(top,
+                                                                 iter(graph[top]),
+                                                                 None)]
         visit_stack = []
         visiting = set()
 
-        scc = []
+        scc: List[T] = []
 
         while call_stack:
             top, children, last_popped_child = call_stack.pop()
@@ -194,7 +209,8 @@ class HeapEntry:
         return self.key < other.key
 
 
-def compute_topological_order(graph, key=None):
+def compute_topological_order(graph: Mapping[T, Iterable[T]],
+                              key: Optional[Callable[[T], Any]] = None) -> List[T]:
     """Compute a topological order of nodes in a directed graph.
 
     :arg graph: A :class:`collections.abc.Mapping` representing a directed
@@ -216,10 +232,8 @@ def compute_topological_order(graph, key=None):
 
     .. versionadded:: 2020.2
     """
-    if key is None:
-        def key(x):
-            # all nodes have the same keys when not provided
-            return 0
+    # all nodes have the same keys when not provided
+    keyfunc = key if key is not None else (lambda x: 0)
 
     from heapq import heapify, heappop, heappush
 
@@ -241,7 +255,7 @@ def compute_topological_order(graph, key=None):
     # heap: list of instances of HeapEntry(n) where 'n' is a node in
     # 'graph' with no predecessor. Nodes with no predecessors are the
     # schedulable candidates.
-    heap = [HeapEntry(n, key(n))
+    heap = [HeapEntry(n, keyfunc(n))
             for n, num_preds in nodes_to_num_predecessors.items()
             if num_preds == 0]
     heapify(heap)
@@ -256,7 +270,7 @@ def compute_topological_order(graph, key=None):
         for child in graph.get(node_to_be_scheduled, ()):
             nodes_to_num_predecessors[child] -= 1
             if nodes_to_num_predecessors[child] == 0:
-                heappush(heap, HeapEntry(child, key(child)))
+                heappush(heap, HeapEntry(child, keyfunc(child)))
 
     if len(order) != total_num_nodes:
         # any node which has a predecessor left is a part of a cycle
@@ -270,7 +284,8 @@ def compute_topological_order(graph, key=None):
 
 # {{{ compute transitive closure
 
-def compute_transitive_closure(graph):
+def compute_transitive_closure(graph: Mapping[T, MutableSet[T]]) -> (
+        Mapping[T, MutableSet[T]]):
     """Compute the transitive closure of a directed graph using Warshall's
         algorithm.
 
@@ -305,7 +320,7 @@ def compute_transitive_closure(graph):
 
 # {{{ check for cycle
 
-def contains_cycle(graph):
+def contains_cycle(graph: Mapping[T, Iterable[T]]) -> bool:
     """Determine whether a graph contains a cycle.
 
     :arg graph: A :class:`collections.abc.Mapping` representing a directed
@@ -329,7 +344,8 @@ def contains_cycle(graph):
 
 # {{{ compute induced subgraph
 
-def compute_induced_subgraph(graph, subgraph_nodes):
+def compute_induced_subgraph(graph: Mapping[T, Set[T]],
+                             subgraph_nodes: Set[T]) -> Mapping[T, Set[T]]:
     """Compute the induced subgraph formed by a subset of the vertices in a
         graph.