diff --git a/doc/conf.py b/doc/conf.py index 133844183bdd4eb38fffa5ae56cc79fbccd510b0..659ed9e9b825f6b98d6a302e6b8218ea6a2e66bb 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -37,3 +37,7 @@ nitpick_ignore_regex = [ ] nitpicky = True + +autodoc_type_aliases = {"GraphT": "pytools.graph.GraphT", + "NodeT": "pytools.graph.NodeT", + } diff --git a/pytools/graph.py b/pytools/graph.py index 7f227bc61715a0674626b37593dbd037ecbca5ad..0bcfba31ed96466dc571599bbeaf02457e4b9220 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2009-2013 Andreas Kloeckner Copyright (C) 2020 Matt Wala @@ -27,7 +29,11 @@ THE SOFTWARE. __doc__ = """ Graph Algorithms -========================= +================ + +.. note:: + + These functions are mostly geared towards directed graphs (digraphs). .. autofunction:: reverse_graph .. autofunction:: a_star @@ -43,29 +49,44 @@ Graph Algorithms Type Variables Used ------------------- -.. class:: T +.. class:: NodeT + + Type of a graph node, can be any hashable type. + +.. class:: GraphT - Any type. + A :class:`collections.abc.Mapping` representing a directed + graph. The mapping contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Collection` of its + successor nodes. Note that most functions expect that every graph node + is included as a key in the graph. """ -from typing import (Collection, TypeVar, Mapping, List, Optional, Any, - Callable, Set, MutableSet, Dict, Iterator, Tuple, FrozenSet) +from typing import (Collection, Mapping, List, Optional, Any, + Callable, Set, MutableSet, Dict, Iterator, Tuple, + Hashable) +try: + from typing import TypeAlias +except ImportError: + from typing_extensions import TypeAlias -T = TypeVar("T") + +NodeT: TypeAlias = Hashable + + +GraphT: TypeAlias = Mapping[NodeT, Collection[NodeT]] # {{{ reverse_graph -def reverse_graph(graph: Mapping[T, Collection[T]]) -> Dict[T, FrozenSet[T]]: +def reverse_graph(graph: GraphT) -> GraphT: """ - Reverses a graph. + Reverses a graph *graph*. - :param graph: A :class:`dict` representation of a directed graph, mapping each - node to other nodes to which it is connected by edges. :returns: A :class:`dict` representing *graph* with edges reversed. """ - result: Dict[T, Set[T]] = {} + result: Dict[NodeT, Set[NodeT]] = {} for node_key, successor_nodes in graph.items(): # Make sure every node is in the result even if it has no successors @@ -82,10 +103,10 @@ def reverse_graph(graph: Mapping[T, Collection[T]]) -> Dict[T, FrozenSet[T]]: # {{{ a_star def a_star( # pylint: disable=too-many-locals - initial_state: T, goal_state: T, neighbor_map: Mapping[T, Collection[T]], - estimate_remaining_cost: Optional[Callable[[T], float]] = None, - get_step_cost: Callable[[Any, T], float] = lambda x, y: 1 - ) -> List[T]: + initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT, + estimate_remaining_cost: Optional[Callable[[NodeT], float]] = None, + get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1 + ) -> List[NodeT]: """ With the default cost and heuristic, this amounts to Dijkstra's algorithm. """ @@ -94,7 +115,7 @@ def a_star( # pylint: disable=too-many-locals if estimate_remaining_cost is None: # pylint: disable=function-redefined - def estimate_remaining_cost(x: T) -> float: + def estimate_remaining_cost(x: NodeT) -> float: if x != goal_state: return 1 else: @@ -103,7 +124,7 @@ def a_star( # pylint: disable=too-many-locals class AStarNode: __slots__ = ["state", "parent", "path_cost"] - def __init__(self, state: T, parent: Any, path_cost: float) -> None: + def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None: self.state = state self.parent = parent self.path_cost = path_cost @@ -148,21 +169,21 @@ def a_star( # pylint: disable=too-many-locals # {{{ compute SCCs with Tarjan's algorithm -def compute_sccs(graph: Mapping[T, Collection[T]]) -> List[List[T]]: +def compute_sccs(graph: GraphT) -> List[List[NodeT]]: to_search = set(graph.keys()) - visit_order: Dict[T, int] = {} + visit_order: Dict[NodeT, int] = {} scc_root = {} sccs = [] while to_search: top = next(iter(to_search)) - call_stack: List[Tuple[T, Iterator[T], Optional[T]]] = [(top, + call_stack: List[Tuple[NodeT, Iterator[NodeT], Optional[NodeT]]] = [(top, iter(graph[top]), None)] visit_stack = [] visiting = set() - scc: List[T] = [] + scc: List[NodeT] = [] while call_stack: top, children, last_popped_child = call_stack.pop() @@ -215,7 +236,7 @@ class CycleError(Exception): :attr node: Node in a directed graph that is part of a cycle. """ - def __init__(self, node: T) -> None: + def __init__(self, node: NodeT) -> None: self.node = node @@ -227,7 +248,7 @@ class HeapEntry: Only needs to define :func:`pytools.graph.__lt__` according to . """ - def __init__(self, node: T, key: Any) -> None: + def __init__(self, node: NodeT, key: Any) -> None: self.node = node self.key = key @@ -235,15 +256,11 @@ class HeapEntry: return self.key < other.key -def compute_topological_order(graph: Mapping[T, Collection[T]], - key: Optional[Callable[[T], Any]] = None) -> List[T]: +def compute_topological_order(graph: GraphT, + key: Optional[Callable[[NodeT], Any]] = None) \ + -> List[NodeT]: """Compute a topological order of nodes in a directed graph. - :arg graph: A :class:`collections.abc.Mapping` representing a directed - graph. The dictionary contains one key representing each node in the - graph, and this key maps to a :class:`collections.abc.Collection` of its - successor nodes. - :arg key: A custom key function may be supplied to determine the order in break-even cases. Expects a function of one argument that is used to extract a comparison key from each node of the *graph*. @@ -310,13 +327,12 @@ def compute_topological_order(graph: Mapping[T, Collection[T]], # {{{ compute transitive closure -def compute_transitive_closure(graph: Mapping[T, MutableSet[T]]) -> ( - Mapping[T, MutableSet[T]]): +def compute_transitive_closure(graph: Mapping[NodeT, MutableSet[NodeT]]) -> GraphT: """Compute the transitive closure of a directed graph using Warshall's - algorithm. + algorithm. :arg graph: A :class:`collections.abc.Mapping` representing a directed - graph. The dictionary contains one key representing each node in the + graph. The mapping contains one key representing each node in the graph, and this key maps to a :class:`collections.abc.MutableSet` of nodes that are connected to the node by outgoing edges. This graph may contain cycles. This object must be picklable. Every graph node must @@ -346,14 +362,9 @@ def compute_transitive_closure(graph: Mapping[T, MutableSet[T]]) -> ( # {{{ check for cycle -def contains_cycle(graph: Mapping[T, Collection[T]]) -> bool: +def contains_cycle(graph: GraphT) -> bool: """Determine whether a graph contains a cycle. - :arg graph: A :class:`collections.abc.Mapping` representing a directed - graph. The dictionary contains one key representing each node in the - graph, and this key maps to a :class:`collections.abc.Collection` of - nodes that are connected to the node by outgoing edges. - :returns: A :class:`bool` indicating whether the graph contains a cycle. .. versionadded:: 2020.2 @@ -370,13 +381,13 @@ def contains_cycle(graph: Mapping[T, Collection[T]]) -> bool: # {{{ compute induced subgraph -def compute_induced_subgraph(graph: Mapping[T, Set[T]], - subgraph_nodes: Set[T]) -> Mapping[T, Set[T]]: +def compute_induced_subgraph(graph: Mapping[NodeT, Set[NodeT]], + subgraph_nodes: Set[NodeT]) -> GraphT: """Compute the induced subgraph formed by a subset of the vertices in a - graph. + graph. :arg graph: A :class:`collections.abc.Mapping` representing a directed - graph. The dictionary contains one key representing each node in the + graph. The mapping contains one key representing each node in the graph, and this key maps to a :class:`collections.abc.Set` of nodes that are connected to the node by outgoing edges. @@ -400,17 +411,12 @@ def compute_induced_subgraph(graph: Mapping[T, Set[T]], # {{{ validate graph -def validate_graph(graph: Mapping[T, Collection[T]]) -> None: +def validate_graph(graph: GraphT) -> None: """ Validates that all successor nodes of each node in *graph* are keys in *graph* itself. Raises a :class:`ValueError` if not. - - :arg graph: A :class:`collections.abc.Mapping` representing a directed - graph. The dictionary contains one key representing each node in the - graph, and this key maps to a :class:`collections.abc.Collection` of nodes - that are connected to the node by outgoing edges. """ - seen_nodes: Set[T] = set() + seen_nodes: Set[NodeT] = set() for children in graph.values(): seen_nodes.update(children) @@ -424,15 +430,12 @@ def validate_graph(graph: Mapping[T, Collection[T]]) -> None: # {{{ -def is_connected(graph: Mapping[T, Collection[T]]) -> bool: +def is_connected(graph: GraphT) -> bool: """ Returns whether all nodes in *graph* are connected, ignoring the edge direction. - :arg graph: A :class:`collections.abc.Mapping` representing a directed - graph. The dictionary contains one key representing each node in the - graph, and this key maps to a :class:`collections.abc.Collection` of nodes - that are connected to the node by outgoing edges. + :returns: A :class:`bool` indicating whether the graph is connected. """ if not graph: # https://cs.stackexchange.com/questions/52815/is-a-graph-of-zero-nodes-vertices-connected @@ -446,7 +449,7 @@ def is_connected(graph: Mapping[T, Collection[T]]) -> bool: for child in children: undirected_graph[child].add(node) - def dfs(node: T) -> None: + def dfs(node: NodeT) -> None: visited.add(node) for child in undirected_graph[node]: if child not in visited: