From 0f93fd0d3851b9c4ba1290d2437debea8941f98c Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Fri, 11 Nov 2022 18:57:03 -0600 Subject: [PATCH] Add validate_graph (#155) * add validate_graph * mypy fix * fix comment * better exception message --- pytools/__init__.py | 10 ++++++++-- pytools/graph.py | 25 +++++++++++++++++++++++++ test/test_graph_tools.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index e018f94..5f1203e 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -41,9 +41,15 @@ import math from sys import intern try: - from typing import SupportsIndex, ParamSpec, Concatenate + from typing import SupportsIndex, Concatenate except ImportError: - from typing_extensions import SupportsIndex, ParamSpec, Concatenate + from typing_extensions import SupportsIndex, Concatenate + +try: + from typing import ParamSpec +except ImportError: + from typing_extensions import ParamSpec # type: ignore[assignment] + # These are deprecated and will go away in 2022. all = builtins.all diff --git a/pytools/graph.py b/pytools/graph.py index 09c1b41..1eb9147 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -37,6 +37,7 @@ Graph Algorithms .. autofunction:: compute_transitive_closure .. autofunction:: contains_cycle .. autofunction:: compute_induced_subgraph +.. autofunction:: validate_graph Type Variables Used ------------------- @@ -394,4 +395,28 @@ def compute_induced_subgraph(graph: Mapping[T, Set[T]], # }}} + +# {{{ validate graph + +def validate_graph(graph: Mapping[T, Collection[T]]) -> None: + """ + Validates that all successor nodes of each node in *graph* are keys in + *graph* itself. Raises a :class:`ValueError` if not. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Collection` of nodes + that are connected to the node by outgoing edges. + """ + seen_nodes: Set[T] = set() + + for children in graph.values(): + seen_nodes.update(children) + + if not seen_nodes <= graph.keys(): + raise ValueError( + f"invalid graph, missing keys: {seen_nodes-graph.keys()}") + +# }}} + # vim: foldmethod=marker diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py index 5c7a06e..4bb9b8f 100644 --- a/test/test_graph_tools.py +++ b/test/test_graph_tools.py @@ -312,6 +312,42 @@ def test_reverse_graph(): assert graph == reverse_graph(reverse_graph(graph)) +def test_validate_graph(): + from pytools.graph import validate_graph + graph1 = { + "d": set("c"), + "b": set("a"), + "a": set(), + "c": set("a"), + } + + validate_graph(graph1) + + graph2 = { + "d": set("d"), + "b": set("c"), + "a": set("b"), + "c": set("a"), + } + + validate_graph(graph2) + + graph3 = { + "a": {"b", "c"}, + "b": {"d", "e"}, + "c": {"d", "f"}, + "d": set(), + "e": set(), + "f": {"g"}, + "g": {"h", "i", "j"}, # h, i, j missing from keys + } + + with pytest.raises(ValueError): + validate_graph(graph3) + + validate_graph({}) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) -- GitLab