From 63b274785c53d4dbd0317e8fe9d522e5e15ec051 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 13 Dec 2022 00:13:01 -0600
Subject: [PATCH] Make GraphT a parametric type alias

---
 pytools/graph.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/pytools/graph.py b/pytools/graph.py
index 0bcfba3..69421d3 100644
--- a/pytools/graph.py
+++ b/pytools/graph.py
@@ -64,7 +64,7 @@ Type Variables Used
 
 from typing import (Collection, Mapping, List, Optional, Any,
                     Callable, Set, MutableSet, Dict, Iterator, Tuple,
-                    Hashable)
+                    Hashable, TypeVar)
 
 try:
     from typing import TypeAlias
@@ -72,15 +72,15 @@ except ImportError:
     from typing_extensions import TypeAlias
 
 
-NodeT: TypeAlias = Hashable
+NodeT = TypeVar("NodeT", bound=Hashable)
 
 
-GraphT: TypeAlias = Mapping[NodeT, Collection[NodeT]]
+GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]]
 
 
 # {{{ reverse_graph
 
-def reverse_graph(graph: GraphT) -> GraphT:
+def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]:
     """
     Reverses a graph *graph*.
 
@@ -103,7 +103,7 @@ def reverse_graph(graph: GraphT) -> GraphT:
 # {{{ a_star
 
 def a_star(  # pylint: disable=too-many-locals
-        initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT,
+        initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT],
         estimate_remaining_cost: Optional[Callable[[NodeT], float]] = None,
         get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1
         ) -> List[NodeT]:
@@ -169,7 +169,7 @@ def a_star(  # pylint: disable=too-many-locals
 
 # {{{ compute SCCs with Tarjan's algorithm
 
-def compute_sccs(graph: GraphT) -> List[List[NodeT]]:
+def compute_sccs(graph: GraphT[NodeT]) -> List[List[NodeT]]:
     to_search = set(graph.keys())
     visit_order: Dict[NodeT, int] = {}
     scc_root = {}
@@ -256,7 +256,7 @@ class HeapEntry:
         return self.key < other.key
 
 
-def compute_topological_order(graph: GraphT,
+def compute_topological_order(graph: GraphT[NodeT],
                               key: Optional[Callable[[NodeT], Any]] = None) \
                               -> List[NodeT]:
     """Compute a topological order of nodes in a directed graph.
@@ -327,7 +327,8 @@ def compute_topological_order(graph: GraphT,
 
 # {{{ compute transitive closure
 
-def compute_transitive_closure(graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT:
+def compute_transitive_closure(
+        graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT[NodeT]:
     """Compute the transitive closure of a directed graph using Warshall's
     algorithm.
 
@@ -362,7 +363,7 @@ def compute_transitive_closure(graph: Mapping[NodeT, MutableSet[NodeT]]) -> Grap
 
 # {{{ check for cycle
 
-def contains_cycle(graph: GraphT) -> bool:
+def contains_cycle(graph: GraphT[NodeT]) -> bool:
     """Determine whether a graph contains a cycle.
 
     :returns: A :class:`bool` indicating whether the graph contains a cycle.
@@ -382,7 +383,7 @@ def contains_cycle(graph: GraphT) -> bool:
 # {{{ compute induced subgraph
 
 def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]],
-                             subgraph_nodes: Set[NodeT]) -> GraphT:
+                             subgraph_nodes: Set[NodeT]) -> GraphT[NodeT]:
     """Compute the induced subgraph formed by a subset of the vertices in a
     graph.
 
@@ -411,7 +412,7 @@ def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]],
 
 # {{{ validate graph
 
-def validate_graph(graph: GraphT) -> None:
+def validate_graph(graph: GraphT[NodeT]) -> None:
     """
     Validates that all successor nodes of each node in *graph* are keys in
     *graph* itself. Raises a :class:`ValueError` if not.
@@ -430,7 +431,7 @@ def validate_graph(graph: GraphT) -> None:
 
 # {{{
 
-def is_connected(graph: GraphT) -> bool:
+def is_connected(graph: GraphT[NodeT]) -> bool:
     """
     Returns whether all nodes in *graph* are connected, ignoring
     the edge direction.
-- 
GitLab