diff --git a/pytools/graph.py b/pytools/graph.py index c9ce84aaf511e819b062ed20fec259b1182daab6..095d76a43d5f923e3215662f716a801798283626 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -38,8 +38,21 @@ Graph Algorithms .. autofunction:: compute_transitive_closure .. autofunction:: contains_cycle .. autofunction:: compute_induced_subgraph + +Type Variables Used +------------------- + +.. class:: T + + Any type. """ +from typing import (TypeVar, Mapping, Iterable, List, Optional, Any, Callable, + Set, MutableSet, Dict, Iterator, Tuple) + + +T = TypeVar("T") + # {{{ a_star @@ -109,19 +122,21 @@ def a_star( # pylint: disable=too-many-locals # {{{ compute SCCs with Tarjan's algorithm -def compute_sccs(graph): +def compute_sccs(graph: Mapping[T, Iterable[T]]) -> List[List[T]]: to_search = set(graph.keys()) - visit_order = {} + visit_order: Dict[T, int] = {} scc_root = {} sccs = [] while to_search: top = next(iter(to_search)) - call_stack = [(top, iter(graph[top]), None)] + call_stack: List[Tuple[T, Iterator[T], Optional[T]]] = [(top, + iter(graph[top]), + None)] visit_stack = [] visiting = set() - scc = [] + scc: List[T] = [] while call_stack: top, children, last_popped_child = call_stack.pop() @@ -194,7 +209,8 @@ class HeapEntry: return self.key < other.key -def compute_topological_order(graph, key=None): +def compute_topological_order(graph: Mapping[T, Iterable[T]], + key: Optional[Callable[[T], Any]] = None) -> List[T]: """Compute a topological order of nodes in a directed graph. :arg graph: A :class:`collections.abc.Mapping` representing a directed @@ -216,10 +232,8 @@ def compute_topological_order(graph, key=None): .. versionadded:: 2020.2 """ - if key is None: - def key(x): - # all nodes have the same keys when not provided - return 0 + # all nodes have the same keys when not provided + keyfunc = key if key is not None else (lambda x: 0) from heapq import heapify, heappop, heappush @@ -241,7 +255,7 @@ def compute_topological_order(graph, key=None): # heap: list of instances of HeapEntry(n) where 'n' is a node in # 'graph' with no predecessor. Nodes with no predecessors are the # schedulable candidates. - heap = [HeapEntry(n, key(n)) + heap = [HeapEntry(n, keyfunc(n)) for n, num_preds in nodes_to_num_predecessors.items() if num_preds == 0] heapify(heap) @@ -256,7 +270,7 @@ def compute_topological_order(graph, key=None): for child in graph.get(node_to_be_scheduled, ()): nodes_to_num_predecessors[child] -= 1 if nodes_to_num_predecessors[child] == 0: - heappush(heap, HeapEntry(child, key(child))) + heappush(heap, HeapEntry(child, keyfunc(child))) if len(order) != total_num_nodes: # any node which has a predecessor left is a part of a cycle @@ -270,7 +284,8 @@ def compute_topological_order(graph, key=None): # {{{ compute transitive closure -def compute_transitive_closure(graph): +def compute_transitive_closure(graph: Mapping[T, MutableSet[T]]) -> ( + Mapping[T, MutableSet[T]]): """Compute the transitive closure of a directed graph using Warshall's algorithm. @@ -305,7 +320,7 @@ def compute_transitive_closure(graph): # {{{ check for cycle -def contains_cycle(graph): +def contains_cycle(graph: Mapping[T, Iterable[T]]) -> bool: """Determine whether a graph contains a cycle. :arg graph: A :class:`collections.abc.Mapping` representing a directed @@ -329,7 +344,8 @@ def contains_cycle(graph): # {{{ compute induced subgraph -def compute_induced_subgraph(graph, subgraph_nodes): +def compute_induced_subgraph(graph: Mapping[T, Set[T]], + subgraph_nodes: Set[T]) -> Mapping[T, Set[T]]: """Compute the induced subgraph formed by a subset of the vertices in a graph.