class HeapEntry:
    """
    Helper class to compare associated keys while comparing the elements in
    heap operations.

    Only needs to define ``__lt__``: the heap operations used by
    :func:`compute_topological_order` order their elements with ``<``.

    .. attribute:: node

        The graph node carried by this entry.

    .. attribute:: key

        The comparison key associated with *node*.
    """
    def __init__(self, node, key):
        self.node = node
        self.key = key

    def __lt__(self, other):
        return self.key < other.key


def _find_node_on_cycle(graph, nodes_to_num_predecessors):
    """Return a node that lies on a cycle of *graph*.

    *nodes_to_num_predecessors* is the predecessor count left over after
    Kahn's algorithm got stuck: exactly the unscheduled nodes have a
    non-zero count, and every one of them has at least one unscheduled
    predecessor.
    """
    # Record one unscheduled predecessor for each unscheduled node. The
    # successors of an unscheduled node are necessarily unscheduled as well.
    some_predecessor = {}
    for node in graph:
        if nodes_to_num_predecessors[node] != 0:
            for child in graph[node]:
                some_predecessor.setdefault(child, node)

    node = next(n for n, num_preds in nodes_to_num_predecessors.items()
                if num_preds != 0)

    # Walking backwards can never leave the (finite) set of unscheduled
    # nodes, so a node must repeat; the first repeated node is on a cycle.
    seen = set()
    while node not in seen:
        seen.add(node)
        node = some_predecessor[node]

    return node


def compute_topological_order(graph, key=None):
    """Compute a topological order of nodes in a directed graph.

    :arg graph: A :class:`collections.abc.Mapping` representing a directed
        graph. The dictionary contains one key representing each node in the
        graph, and this key maps to a :class:`collections.abc.Iterable` of its
        successor nodes.

    :arg key: A custom key function may be supplied to determine the order in
        break-even cases. Expects a function of one argument that is used to
        extract a comparison key from each node of the *graph*.

    :returns: A :class:`list` representing a valid topological ordering of the
        nodes in the directed graph.

    :raises CycleError: If *graph* contains a cycle. The node attached to the
        exception lies on a cycle.

    .. note::

        * Requires the keys of the mapping *graph* to be hashable.
        * Implements `Kahn's algorithm
          <https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm>`__.

    .. versionadded:: 2020.2
    """
    if key is None:
        def key(x):
            # all nodes have the same keys when not provided
            return 0

    from heapq import heapify, heappop, heappush

    order = []

    # {{{ compute nodes_to_num_predecessors

    nodes_to_num_predecessors = {node: 0 for node in graph}

    for node in graph:
        for child in graph[node]:
            # 'child' may be a node that is not itself a key of 'graph'
            nodes_to_num_predecessors[child] = (
                    nodes_to_num_predecessors.get(child, 0) + 1)

    # }}}

    total_num_nodes = len(nodes_to_num_predecessors)

    # heap: list of instances of HeapEntry(n) where 'n' is a node in
    # 'graph' with no predecessor. Nodes with no predecessors are the
    # schedulable candidates.
    heap = [HeapEntry(n, key(n))
            for n, num_preds in nodes_to_num_predecessors.items()
            if num_preds == 0]
    heapify(heap)

    while heap:
        # pick the node with least key
        node_to_be_scheduled = heappop(heap).node
        order.append(node_to_be_scheduled)

        # discard 'node_to_be_scheduled' from the predecessors of its
        # successors since it's been scheduled
        for child in graph.get(node_to_be_scheduled, ()):
            nodes_to_num_predecessors[child] -= 1
            if nodes_to_num_predecessors[child] == 0:
                heappush(heap, HeapEntry(child, key(child)))

    if len(order) != total_num_nodes:
        # A node with predecessors left is either on a cycle or merely
        # reachable from one, so search for a node that is really on a cycle.
        raise CycleError(
                _find_node_on_cycle(graph, nodes_to_num_predecessors))

    return order
def test_prioritzed_topological_sort_examples():
    """Check key-based tie-breaking on two small hand-written DAGs."""

    from pytools.graph import compute_topological_order

    keys = {'a': 4, 'b': 3, 'c': 2, 'e': 1, 'd': 4}
    dag = {
            'a': ['b', 'c'],
            'b': [],
            'c': ['d', 'e'],
            'd': [],
            'e': []}

    assert compute_topological_order(dag, key=keys.get) == [
            'a', 'c', 'e', 'b', 'd']

    keys = {'a': 7, 'b': 2, 'c': 1, 'd': 0}
    dag = {
            'd': set('c'),
            'b': set('a'),
            'a': set(),
            'c': set('a'),
            }

    assert compute_topological_order(dag, key=keys.get) == ['d', 'c', 'b', 'a']


def test_prioritzed_topological_sort():
    """Check on a random DAG that every scheduled node is ready (has no
    unscheduled dependency) and has the least key among all ready nodes.
    """

    import random
    from pytools.graph import compute_topological_order
    rng = random.Random(0)

    def generate_random_graph(nnodes):
        graph = {i: set() for i in range(nnodes)}
        for i in range(nnodes):
            # to avoid cycles only consider edges node_i->node_j where j > i.
            for j in range(i+1, nnodes):
                # randint() is inclusive on both ends, so '<= 2' accepts 3 of
                # the 'nnodes' possible values: edge probability 3/n.
                if rng.randint(0, nnodes - 1) <= 2:
                    graph[i].add(j)
        return graph

    nnodes = rng.randint(40, 100)
    rev_dep_graph = generate_random_graph(nnodes)

    # dep_graph: node -> set of nodes that must be scheduled before it
    dep_graph = {i: set() for i in range(nnodes)}

    for i in range(nnodes):
        for rev_dep in rev_dep_graph[i]:
            dep_graph[rev_dep].add(i)

    keys = [rng.random() for _ in range(nnodes)]
    topo_order = compute_topological_order(rev_dep_graph, key=keys.__getitem__)

    for scheduled_node in topo_order:
        nodes_with_no_deps = {node for node, deps in dep_graph.items()
                              if not deps}

        # check whether the order is a valid topological order
        assert scheduled_node in nodes_with_no_deps
        # check whether priorities are upheld
        assert keys[scheduled_node] == min(
                keys[node] for node in nodes_with_no_deps)

        # 'scheduled_node' is scheduled => no longer a dependency
        dep_graph.pop(scheduled_node)

        for deps in dep_graph.values():
            deps.discard(scheduled_node)

    # every node must have been scheduled exactly once
    assert not dep_graph


if __name__ == "__main__":
    if len(sys.argv) > 1:
        exec(sys.argv[1])