From 63b274785c53d4dbd0317e8fe9d522e5e15ec051 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 13 Dec 2022 00:13:01 -0600 Subject: [PATCH] Make GraphT a parametric type alias --- pytools/graph.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/pytools/graph.py b/pytools/graph.py index 0bcfba3..69421d3 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -64,7 +64,7 @@ Type Variables Used from typing import (Collection, Mapping, List, Optional, Any, Callable, Set, MutableSet, Dict, Iterator, Tuple, - Hashable) + Hashable, TypeVar) try: from typing import TypeAlias @@ -72,15 +72,15 @@ except ImportError: from typing_extensions import TypeAlias -NodeT: TypeAlias = Hashable +NodeT = TypeVar("NodeT", bound=Hashable) -GraphT: TypeAlias = Mapping[NodeT, Collection[NodeT]] +GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]] # {{{ reverse_graph -def reverse_graph(graph: GraphT) -> GraphT: +def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]: """ Reverses a graph *graph*. @@ -103,7 +103,7 @@ def reverse_graph(graph: GraphT) -> GraphT: # {{{ a_star def a_star( # pylint: disable=too-many-locals - initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT, + initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT], estimate_remaining_cost: Optional[Callable[[NodeT], float]] = None, get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1 ) -> List[NodeT]: @@ -169,7 +169,7 @@ def a_star( # pylint: disable=too-many-locals # {{{ compute SCCs with Tarjan's algorithm -def compute_sccs(graph: GraphT) -> List[List[NodeT]]: +def compute_sccs(graph: GraphT[NodeT]) -> List[List[NodeT]]: to_search = set(graph.keys()) visit_order: Dict[NodeT, int] = {} scc_root = {} @@ -256,7 +256,7 @@ class HeapEntry: return self.key < other.key -def compute_topological_order(graph: GraphT, +def compute_topological_order(graph: GraphT[NodeT], key: Optional[Callable[[NodeT], Any]] = None) \ -> List[NodeT]: """Compute a topological order of nodes in a directed graph. @@ -327,7 +327,8 @@ def compute_topological_order(graph: GraphT, # {{{ compute transitive closure -def compute_transitive_closure(graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT: +def compute_transitive_closure( + graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT[NodeT]: """Compute the transitive closure of a directed graph using Warshall's algorithm. @@ -362,7 +363,7 @@ def compute_transitive_closure(graph: Mapping[NodeT, MutableSet[NodeT]]) -> Grap # {{{ check for cycle -def contains_cycle(graph: GraphT) -> bool: +def contains_cycle(graph: GraphT[NodeT]) -> bool: """Determine whether a graph contains a cycle. :returns: A :class:`bool` indicating whether the graph contains a cycle. @@ -382,7 +383,7 @@ def contains_cycle(graph: GraphT) -> bool: # {{{ compute induced subgraph def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]], - subgraph_nodes: Set[NodeT]) -> GraphT: + subgraph_nodes: Set[NodeT]) -> GraphT[NodeT]: """Compute the induced subgraph formed by a subset of the vertices in a graph. @@ -411,7 +412,7 @@ def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]], # {{{ validate graph -def validate_graph(graph: GraphT) -> None: +def validate_graph(graph: GraphT[NodeT]) -> None: """ Validates that all successor nodes of each node in *graph* are keys in *graph* itself. Raises a :class:`ValueError` if not. @@ -430,7 +431,7 @@ def validate_graph(graph: GraphT) -> None: # {{{ -def is_connected(graph: GraphT) -> bool: +def is_connected(graph: GraphT[NodeT]) -> bool: """ Returns whether all nodes in *graph* are connected, ignoring the edge direction. -- GitLab