diff --git a/pytools/graph.py b/pytools/graph.py index b00988aa4983942e2600802af3d3ecba7f73f4c5..09c1b4121f1575d07e4371567d9a0ce1daccc973 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -29,6 +29,7 @@ __doc__ = """ Graph Algorithms ========================= +.. autofunction:: reverse_graph .. autofunction:: a_star .. autofunction:: compute_sccs .. autoclass:: CycleError @@ -45,13 +46,37 @@ Type Variables Used Any type. """ -from typing import (TypeVar, Mapping, Iterable, List, Optional, Any, Callable, - Set, MutableSet, Dict, Iterator, Tuple) +from typing import (Collection, TypeVar, Mapping, Iterable, List, Optional, Any, + Callable, Set, MutableSet, Dict, Iterator, Tuple, FrozenSet) T = TypeVar("T") +# {{{ reverse_graph + +def reverse_graph(graph: Mapping[T, Collection[T]]) -> Dict[T, FrozenSet[T]]: + """ + Reverses a graph. + + :param graph: A :class:`dict` representation of a directed graph, mapping each + node to other nodes to which it is connected by edges. + :returns: A :class:`dict` representing *graph* with edges reversed. + """ + result: Dict[T, Set[T]] = {} + + for node_key, successor_nodes in graph.items(): + # Make sure every node is in the result even if it has no successors + result.setdefault(node_key, set()) + + for successor in successor_nodes: + result.setdefault(successor, set()).add(node_key) + + return {k: frozenset(v) for k, v in result.items()} + +# }}} + + # {{{ a_star def a_star( # pylint: disable=too-many-locals diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py index bdcd2fbdcc9764f515ee37c7f6ad33040781e14c..5c7a06e844a6c4bd96e23606f1c837c101067625 100644 --- a/test/test_graph_tools.py +++ b/test/test_graph_tools.py @@ -294,6 +294,24 @@ def test_prioritzed_topological_sort(): assert len(dep_graph) == 0 +def test_reverse_graph(): + graph = { + "a": frozenset(("b", "c")), + "b": frozenset(("d", "e")), + "c": frozenset(("d", "f")), + "d": frozenset(), + "e": frozenset(), + "f": frozenset(("g",)), + "g": frozenset(("h", "i", "j")), + "h": frozenset(), + "i": frozenset(), + "j": frozenset(), + } + + from pytools.graph import reverse_graph + assert graph == reverse_graph(reverse_graph(graph)) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])