diff --git a/pytools/graph.py b/pytools/graph.py index 2dbb09f256ee63cbe54e6ba442612479cbfb85c8..7f227bc61715a0674626b37593dbd037ecbca5ad 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -38,6 +38,7 @@ Graph Algorithms .. autofunction:: contains_cycle .. autofunction:: compute_induced_subgraph .. autofunction:: validate_graph +.. autofunction:: is_connected Type Variables Used ------------------- @@ -420,4 +421,40 @@ def validate_graph(graph: Mapping[T, Collection[T]]) -> None: # }}} + +# {{{ + +def is_connected(graph: Mapping[T, Collection[T]]) -> bool: + """ + Returns whether all nodes in *graph* are connected, ignoring + the edge direction. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Collection` of nodes + that are connected to the node by outgoing edges. + """ + if not graph: + # https://cs.stackexchange.com/questions/52815/is-a-graph-of-zero-nodes-vertices-connected + return True + + visited = set() + + undirected_graph = {node: set(children) for node, children in graph.items()} + + for node, children in graph.items(): + for child in children: + undirected_graph[child].add(node) + + def dfs(node: T) -> None: + visited.add(node) + for child in undirected_graph[node]: + if child not in visited: + dfs(child) + + dfs(next(iter(graph.keys()))) + + return visited == graph.keys() + + # vim: foldmethod=marker diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py index 701cd5776b25959b4d26bbe8be54651f20c11104..2f14f1ada23809ef84e6c57a7cfa5bcd85170742 100644 --- a/test/test_graph_tools.py +++ b/test/test_graph_tools.py @@ -348,6 +348,53 @@ def test_validate_graph(): validate_graph({}) +def test_is_connected(): + from pytools.graph import is_connected + graph1 = { + "d": set("c"), + "b": set("a"), + "a": set(), + "c": set("a"), + } + + assert is_connected(graph1) + + graph2 = { + "d": set("d"), + "b": set("c"), + "a": set("b"), + "c": set("a"), + } + + assert not is_connected(graph2) + + graph3 = { + "a": {"b", "c"}, + "b": {"d", "e"}, + "c": {"d", "f"}, + "d": set(), + "e": set(), + "f": {"g"}, + "g": {}, + } + + assert is_connected(graph3) + + graph4 = { + "a": {"c"}, + "b": {"d", "e"}, + "c": {"f"}, + "d": set(), + "e": set(), + "f": {"g"}, + "g": {}, + } + + assert not is_connected(graph4) + + assert is_connected({}) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])