diff --git a/pytools/__init__.py b/pytools/__init__.py index e018f949836cef6420c50ecf75d24e6aed5fe864..5f1203e16ca54bab39519c6c85cd26e1b4b257a0 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -41,9 +41,15 @@ import math from sys import intern try: - from typing import SupportsIndex, ParamSpec, Concatenate + from typing import SupportsIndex, Concatenate except ImportError: - from typing_extensions import SupportsIndex, ParamSpec, Concatenate + from typing_extensions import SupportsIndex, Concatenate + +try: + from typing import ParamSpec +except ImportError: + from typing_extensions import ParamSpec # type: ignore[assignment] + # These are deprecated and will go away in 2022. all = builtins.all diff --git a/pytools/graph.py b/pytools/graph.py index 09c1b4121f1575d07e4371567d9a0ce1daccc973..1eb9147d40dd697d166e53b0ddac5898705a0719 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -37,6 +37,7 @@ Graph Algorithms .. autofunction:: compute_transitive_closure .. autofunction:: contains_cycle .. autofunction:: compute_induced_subgraph +.. autofunction:: validate_graph Type Variables Used ------------------- @@ -394,4 +395,28 @@ def compute_induced_subgraph(graph: Mapping[T, Set[T]], # }}} + +# {{{ validate graph + +def validate_graph(graph: Mapping[T, Collection[T]]) -> None: + """ + Validates that all successor nodes of each node in *graph* are keys in + *graph* itself. Raises a :class:`ValueError` if not. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Collection` of nodes + that are connected to the node by outgoing edges. + """ + seen_nodes: Set[T] = set() + + for children in graph.values(): + seen_nodes.update(children) + + if not seen_nodes <= graph.keys(): + raise ValueError( + f"invalid graph, missing keys: {seen_nodes-graph.keys()}") + +# }}} + # vim: foldmethod=marker diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py index 5c7a06e844a6c4bd96e23606f1c837c101067625..4bb9b8ff9b1205fe846adc6821a1525ff027e27e 100644 --- a/test/test_graph_tools.py +++ b/test/test_graph_tools.py @@ -312,6 +312,42 @@ def test_reverse_graph(): assert graph == reverse_graph(reverse_graph(graph)) +def test_validate_graph(): + from pytools.graph import validate_graph + graph1 = { + "d": set("c"), + "b": set("a"), + "a": set(), + "c": set("a"), + } + + validate_graph(graph1) + + graph2 = { + "d": set("d"), + "b": set("c"), + "a": set("b"), + "c": set("a"), + } + + validate_graph(graph2) + + graph3 = { + "a": {"b", "c"}, + "b": {"d", "e"}, + "c": {"d", "f"}, + "d": set(), + "e": set(), + "f": {"g"}, + "g": {"h", "i", "j"}, # h, i, j missing from keys + } + + with pytest.raises(ValueError): + validate_graph(graph3) + + validate_graph({}) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])