diff --git a/doc/conf.py b/doc/conf.py index c87091c3ef34558f6bc5d70e82fd3f09a25bac0a..bf3d374a87bd8d4ca4967c1ef414ffebe9bba98a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -33,6 +33,7 @@ # ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode'] @@ -114,7 +115,7 @@ html_sidebars = { # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +# html_static_path = ['_static'] # -- Options for HTMLHelp output ------------------------------------------ @@ -174,4 +175,9 @@ texinfo_documents = [ ] +intersphinx_mapping = { + 'http://docs.python.org/dev': None, + 'http://docs.scipy.org/doc/numpy/': None, + } + diff --git a/doc/graph.rst b/doc/graph.rst new file mode 100644 index 0000000000000000000000000000000000000000..7ad4b810ae952ccac984b185ff0c69e69545f691 --- /dev/null +++ b/doc/graph.rst @@ -0,0 +1 @@ +.. automodule:: pytools.graph diff --git a/doc/index.rst b/doc/index.rst index fb3dedd0fa7e8e984ef8c06d935d9893819fc6c8..d2860d0bd844c6935545785876cdc5572ed7b0a4 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -8,6 +8,7 @@ Welcome to pytools's documentation! reference obj_array persistent_dict + graph misc Indices and tables diff --git a/pytools/__init__.py b/pytools/__init__.py index 778976a81d49acf7c5b0d1a0f53afd0d19f1d440..d3dcb7722808e72c5166e1c3fc45a6e23992c1a1 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -92,11 +92,6 @@ Permutations, Tuples, Integer sequences .. autofunction:: generate_permutations .. autofunction:: generate_unique_permutations -Graph Algorithms ----------------- - -.. autofunction:: a_star - Formatting ---------- @@ -1369,68 +1364,27 @@ def get_write_to_map_from_permutation(original, permuted): # }}} -# {{{ graph algorithms - -def a_star( # pylint: disable=too-many-locals - initial_state, goal_state, neighbor_map, - estimate_remaining_cost=None, - get_step_cost=lambda x, y: 1 - ): - """ - With the default cost and heuristic, this amounts to Dijkstra's algorithm. - """ - - from heapq import heappop, heappush - - if estimate_remaining_cost is None: - def estimate_remaining_cost(x): # pylint: disable=function-redefined - if x != goal_state: - return 1 - else: - return 0 - - class AStarNode(object): - __slots__ = ["state", "parent", "path_cost"] - - def __init__(self, state, parent, path_cost): - self.state = state - self.parent = parent - self.path_cost = path_cost - - inf = float("inf") - init_remcost = estimate_remaining_cost(initial_state) - assert init_remcost != inf +# {{{ code maintenance - queue = [(init_remcost, AStarNode(initial_state, parent=None, path_cost=0))] - visited_states = set() +class MovedFunctionDeprecationWrapper: + def __init__(self, f): + self.f = f - while queue: - _, top = heappop(queue) - visited_states.add(top.state) + def __call__(self, *args, **kwargs): + from warnings import warn + warn("This function is deprecated. Use %s.%s instead." % ( + self.f.__module__, self.f.__name__), + DeprecationWarning, stacklevel=2) - if top.state == goal_state: - result = [] - it = top - while it is not None: - result.append(it.state) - it = it.parent - return result[::-1] + return self.f(*args, **kwargs) - for state in neighbor_map[top.state]: - if state in visited_states: - continue +# }}} - remaining_cost = estimate_remaining_cost(state) - if remaining_cost == inf: - continue - step_cost = get_step_cost(top, state) - estimated_path_cost = top.path_cost+step_cost+remaining_cost - heappush(queue, - (estimated_path_cost, - AStarNode(state, top, path_cost=top.path_cost + step_cost))) +# {{{ graph algorithms - raise RuntimeError("no solution") +from pytools.graph import a_star as a_star_moved +a_star = MovedFunctionDeprecationWrapper(a_star_moved) # }}} @@ -1639,23 +1593,6 @@ class CPyUserInterface(object): # }}} -# {{{ code maintenance - -class MovedFunctionDeprecationWrapper: - def __init__(self, f): - self.f = f - - def __call__(self, *args, **kwargs): - from warnings import warn - warn("This function is deprecated. Use %s.%s instead." % ( - self.f.__module__, self.f.__name__), - DeprecationWarning, stacklevel=2) - - return self.f(*args, **kwargs) - -# }}} - - # {{{ debugging class StderrToStdout(object): diff --git a/pytools/graph.py b/pytools/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..12797679c9ac2ed8c0c0532814a8c726d589ea2b --- /dev/null +++ b/pytools/graph.py @@ -0,0 +1,333 @@ +from __future__ import division, absolute_import, print_function + +__copyright__ = """ +Copyright (C) 2009-2013 Andreas Kloeckner +Copyright (C) 2020 Matt Wala +Copyright (C) 2020 James Stevens +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +__doc__ = """ +Graph Algorithms +========================= + +.. autofunction:: a_star +.. autofunction:: compute_sccs +.. autoclass:: CycleError +.. autofunction:: compute_topological_order +.. autofunction:: compute_transitive_closure +.. autofunction:: contains_cycle +.. autofunction:: compute_induced_subgraph +""" + + +# {{{ a_star + +def a_star( # pylint: disable=too-many-locals + initial_state, goal_state, neighbor_map, + estimate_remaining_cost=None, + get_step_cost=lambda x, y: 1 + ): + """ + With the default cost and heuristic, this amounts to Dijkstra's algorithm. + """ + + from heapq import heappop, heappush + + if estimate_remaining_cost is None: + def estimate_remaining_cost(x): # pylint: disable=function-redefined + if x != goal_state: + return 1 + else: + return 0 + + class AStarNode(object): + __slots__ = ["state", "parent", "path_cost"] + + def __init__(self, state, parent, path_cost): + self.state = state + self.parent = parent + self.path_cost = path_cost + + inf = float("inf") + init_remcost = estimate_remaining_cost(initial_state) + assert init_remcost != inf + + queue = [(init_remcost, AStarNode(initial_state, parent=None, path_cost=0))] + visited_states = set() + + while queue: + _, top = heappop(queue) + visited_states.add(top.state) + + if top.state == goal_state: + result = [] + it = top + while it is not None: + result.append(it.state) + it = it.parent + return result[::-1] + + for state in neighbor_map[top.state]: + if state in visited_states: + continue + + remaining_cost = estimate_remaining_cost(state) + if remaining_cost == inf: + continue + step_cost = get_step_cost(top, state) + + estimated_path_cost = top.path_cost+step_cost+remaining_cost + heappush(queue, + (estimated_path_cost, + AStarNode(state, top, path_cost=top.path_cost + step_cost))) + + raise RuntimeError("no solution") + +# }}} + + +# {{{ compute SCCs with Tarjan's algorithm + +def compute_sccs(graph): + to_search = set(graph.keys()) + visit_order = {} + scc_root = {} + sccs = [] + + while to_search: + top = next(iter(to_search)) + call_stack = [(top, iter(graph[top]), None)] + visit_stack = [] + visiting = set() + + scc = [] + + while call_stack: + top, children, last_popped_child = call_stack.pop() + + if top not in visiting: + # Unvisited: mark as visited, initialize SCC root. + count = len(visit_order) + visit_stack.append(top) + visit_order[top] = count + scc_root[top] = count + visiting.add(top) + to_search.discard(top) + + # Returned from a recursion, update SCC. + if last_popped_child is not None: + scc_root[top] = min( + scc_root[top], + scc_root[last_popped_child]) + + for child in children: + if child not in visit_order: + # Recurse. + call_stack.append((top, children, child)) + call_stack.append((child, iter(graph[child]), None)) + break + if child in visiting: + scc_root[top] = min( + scc_root[top], + visit_order[child]) + else: + if scc_root[top] == visit_order[top]: + scc = [] + while visit_stack[-1] != top: + scc.append(visit_stack.pop()) + scc.append(visit_stack.pop()) + for item in scc: + visiting.remove(item) + sccs.append(scc) + + return sccs + +# }}} + + +# {{{ compute topological order + +class CycleError(Exception): + """Raised when a topological ordering cannot be computed due to a cycle.""" + pass + + +def compute_topological_order(graph): + """Compute a toplogical order of nodes in a directed graph. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Iterable` of + nodes that are connected to the node by outgoing edges. + + :returns: A :class:`list` representing a valid topological ordering of the + nodes in the directed graph. + + .. versionadded:: 2020.2 + """ + + # find a valid ordering of graph nodes + reverse_order = [] + visited = set() + visiting = set() + + # go through each node + for root in graph: + + if root in visited: + # already encountered root as someone else's child + # and processed it at that time + continue + + stack = [(root, iter(graph[root]))] + visiting.add(root) + + while stack: + node, children = stack.pop() + + for child in children: + # note: each iteration removes child from children + if child in visiting: + raise CycleError() + + if child in visited: + continue + + visiting.add(child) + + # put (node, remaining children) back on stack + stack.append((node, children)) + + # put (child, grandchildren) on stack + stack.append((child, iter(graph.get(child, ())))) + break + else: + # loop did not break, + # so either this is a leaf or all children have been visited + visiting.remove(node) + visited.add(node) + reverse_order.append(node) + + return list(reversed(reverse_order)) + +# }}} + + +# {{{ compute transitive closure + +def compute_transitive_closure(graph): + """Compute the transitive closure of a directed graph using Warshall's + algorithm. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.MutableSet` of + nodes that are connected to the node by outgoing edges. This graph may + contain cycles. This object must be picklable. Every graph node must + be included as a key in the graph. + + :returns: The transitive closure of the graph, represented using the same + data type. + + .. versionadded:: 2020.2 + """ + # Warshall's algorithm + + from copy import deepcopy + closure = deepcopy(graph) + + # (assumes all graph nodes are included in keys) + for k in graph.keys(): + for n1 in graph.keys(): + for n2 in graph.keys(): + if k in closure[n1] and n2 in closure[k]: + closure[n1].add(n2) + + return closure + +# }}} + + +# {{{ check for cycle + +def contains_cycle(graph): + """Determine whether a graph contains a cycle. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Iterable` of + nodes that are connected to the node by outgoing edges. + + :returns: A :class:`bool` indicating whether the graph contains a cycle. + + .. versionadded:: 2020.2 + """ + + try: + compute_topological_order(graph) + return False + except CycleError: + return True + +# }}} + + +# {{{ compute induced subgraph + +def compute_induced_subgraph(graph, subgraph_nodes): + """Compute the induced subgraph formed by a subset of the vertices in a + graph. + + :arg graph: A :class:`collections.abc.Mapping` representing a directed + graph. The dictionary contains one key representing each node in the + graph, and this key maps to a :class:`collections.abc.Set` of nodes + that are connected to the node by outgoing edges. + + :arg subgraph_nodes: A :class:`collections.abc.Set` containing a subset of + the graph nodes in the graph. + + :returns: A :class:`dict` representing the induced subgraph formed by + the subset of the vertices included in `subgraph_nodes`. + + .. versionadded:: 2020.2 + """ + + new_graph = {} + for node, children in graph.items(): + if node in subgraph_nodes: + new_graph[node] = children & subgraph_nodes + return new_graph + +# }}} + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() + +# vim: foldmethod=marker diff --git a/pytools/py_codegen.py b/pytools/py_codegen.py index e2a51442bcf91e00eb7027013082cc3dc15118a3..b2a0ef70804697db19eefd79ec056a65e16ed669 100644 --- a/pytools/py_codegen.py +++ b/pytools/py_codegen.py @@ -76,8 +76,8 @@ class PythonCodeGenerator(object): if "\n" in s: s = remove_common_indentation(s) - for l in s.split("\n"): - self.code.append(" "*(4*self.level) + l) + for line in s.split("\n"): + self.code.append(" "*(4*self.level) + line) def indent(self): self.level += 1 diff --git a/pytools/version.py b/pytools/version.py index ea2c3fd8fd40807205fb95b617a3c80dac0fae2e..0371d44ef9ee993b69c91277455f54766aab3c1d 100644 --- a/pytools/version.py +++ b/pytools/version.py @@ -1,3 +1,3 @@ -VERSION = (2020, 1) +VERSION = (2020, 2) VERSION_STATUS = "" VERSION_TEXT = ".".join(str(x) for x in VERSION) + VERSION_STATUS diff --git a/test/test_graph_tools.py b/test/test_graph_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b349c3c05757fb2fc2b484b1d0b8f20c7b42d5 --- /dev/null +++ b/test/test_graph_tools.py @@ -0,0 +1,230 @@ +import sys +import pytest + + +def test_compute_sccs(): + from pytools.graph import compute_sccs + import random + + rng = random.Random(0) + + def generate_random_graph(nnodes): + graph = dict((i, set()) for i in range(nnodes)) + for i in range(nnodes): + for j in range(nnodes): + # Edge probability 2/n: Generates decently interesting inputs. + if rng.randint(0, nnodes - 1) <= 1: + graph[i].add(j) + return graph + + def verify_sccs(graph, sccs): + visited = set() + + def visit(node): + if node in visited: + return [] + else: + visited.add(node) + result = [] + for child in graph[node]: + result = result + visit(child) + return result + [node] + + for scc in sccs: + scc = set(scc) + assert not scc & visited + # Check that starting from each element of the SCC results + # in the same set of reachable nodes. + for scc_root in scc: + visited.difference_update(scc) + result = visit(scc_root) + assert set(result) == scc, (set(result), scc) + + for nnodes in range(10, 20): + for i in range(40): + graph = generate_random_graph(nnodes) + verify_sccs(graph, compute_sccs(graph)) + + +def test_compute_topological_order(): + from pytools.graph import compute_topological_order, CycleError + + empty = {} + assert compute_topological_order(empty) == [] + + disconnected = {1: [], 2: [], 3: []} + assert len(compute_topological_order(disconnected)) == 3 + + line = list(zip(range(10), ([i] for i in range(1, 11)))) + import random + random.seed(0) + random.shuffle(line) + expected = list(range(11)) + assert compute_topological_order(dict(line)) == expected + + claw = {1: [2, 3], 0: [1]} + assert compute_topological_order(claw)[:2] == [0, 1] + + repeated_edges = {1: [2, 2], 2: [0]} + assert compute_topological_order(repeated_edges) == [1, 2, 0] + + self_cycle = {1: [1]} + with pytest.raises(CycleError): + compute_topological_order(self_cycle) + + cycle = {0: [2], 1: [2], 2: [3], 3: [4, 1]} + with pytest.raises(CycleError): + compute_topological_order(cycle) + + +def test_transitive_closure(): + from pytools.graph import compute_transitive_closure + + # simple test + graph = { + 1: set([2, ]), + 2: set([3, ]), + 3: set([4, ]), + 4: set(), + } + + expected_closure = { + 1: set([2, 3, 4, ]), + 2: set([3, 4, ]), + 3: set([4, ]), + 4: set(), + } + + closure = compute_transitive_closure(graph) + + assert closure == expected_closure + + # test with branches that reconnect + graph = { + 1: set([2, ]), + 2: set(), + 3: set([1, ]), + 4: set([1, ]), + 5: set([6, 7, ]), + 6: set([7, ]), + 7: set([1, ]), + 8: set([3, 4, ]), + } + + expected_closure = { + 1: set([2, ]), + 2: set(), + 3: set([1, 2, ]), + 4: set([1, 2, ]), + 5: set([1, 2, 6, 7, ]), + 6: set([1, 2, 7, ]), + 7: set([1, 2, ]), + 8: set([1, 2, 3, 4, ]), + } + + closure = compute_transitive_closure(graph) + + assert closure == expected_closure + + # test with cycles + graph = { + 1: set([2, ]), + 2: set([3, ]), + 3: set([4, ]), + 4: set([1, ]), + } + + expected_closure = { + 1: set([1, 2, 3, 4, ]), + 2: set([1, 2, 3, 4, ]), + 3: set([1, 2, 3, 4, ]), + 4: set([1, 2, 3, 4, ]), + } + + closure = compute_transitive_closure(graph) + + assert closure == expected_closure + + +def test_graph_cycle_finder(): + + from pytools.graph import contains_cycle + + graph = { + "a": set(["b", "c"]), + "b": set(["d", "e"]), + "c": set(["d", "f"]), + "d": set(), + "e": set(), + "f": set(["g", ]), + "g": set(), + } + + assert not contains_cycle(graph) + + graph = { + "a": set(["b", "c"]), + "b": set(["d", "e"]), + "c": set(["d", "f"]), + "d": set(), + "e": set(), + "f": set(["g", ]), + "g": set(["a", ]), + } + + assert contains_cycle(graph) + + graph = { + "a": set(["a", "c"]), + "b": set(["d", "e"]), + "c": set(["d", "f"]), + "d": set(), + "e": set(), + "f": set(["g", ]), + "g": set(), + } + + assert contains_cycle(graph) + + graph = { + "a": set(["a"]), + } + + assert contains_cycle(graph) + + +def test_induced_subgraph(): + + from pytools.graph import compute_induced_subgraph + + graph = { + "a": set(["b", "c"]), + "b": set(["d", "e"]), + "c": set(["d", "f"]), + "d": set(), + "e": set(), + "f": set(["g", ]), + "g": set(["h", "i", "j"]), + } + + node_subset = set(["b", "c", "e", "f", "g"]) + + expected_subgraph = { + "b": set(["e", ]), + "c": set(["f", ]), + "e": set(), + "f": set(["g", ]), + "g": set(), + } + + subgraph = compute_induced_subgraph(graph, node_subset) + + assert subgraph == expected_subgraph + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__])