From 2784aecb8d8361c03a168c322f7f03b6b9e276b6 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Mon, 14 Nov 2022 23:10:47 -0600 Subject: [PATCH] graph: more mypy annotations --- pytools/graph.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pytools/graph.py b/pytools/graph.py index dac3c25..fa2a12f 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -81,10 +81,10 @@ def reverse_graph(graph: Mapping[T, Collection[T]]) -> Dict[T, FrozenSet[T]]: # {{{ a_star def a_star( # pylint: disable=too-many-locals - initial_state, goal_state, neighbor_map, - estimate_remaining_cost=None, - get_step_cost=lambda x, y: 1 - ): + initial_state: T, goal_state: T, neighbor_map: Mapping[T, Collection[T]], + estimate_remaining_cost: Optional[Callable[[T], float]] = None, + get_step_cost: Callable[[Any, T], float] = lambda x, y: 1 + ) -> List[T]: """ With the default cost and heuristic, this amounts to Dijkstra's algorithm. """ @@ -92,7 +92,8 @@ def a_star( # pylint: disable=too-many-locals from heapq import heappop, heappush if estimate_remaining_cost is None: - def estimate_remaining_cost(x): # pylint: disable=function-redefined + # pylint: disable=function-redefined + def estimate_remaining_cost(x: T) -> float: if x != goal_state: return 1 else: @@ -101,7 +102,7 @@ def a_star( # pylint: disable=too-many-locals class AStarNode: __slots__ = ["state", "parent", "path_cost"] - def __init__(self, state, parent, path_cost): + def __init__(self, state: T, parent: Any, path_cost: float) -> None: self.state = state self.parent = parent self.path_cost = path_cost @@ -119,7 +120,7 @@ def a_star( # pylint: disable=too-many-locals if top.state == goal_state: result = [] - it = top + it: Optional[AStarNode] = top while it is not None: result.append(it.state) it = it.parent @@ -213,7 +214,7 @@ class CycleError(Exception): :attr node: Node in a directed graph that is part of a cycle. """ - def __init__(self, node) -> None: + def __init__(self, node: T) -> None: self.node = node @@ -225,11 +226,11 @@ class HeapEntry: Only needs to define :func:`pytools.graph.__lt__` according to <https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>. """ - def __init__(self, node, key) -> None: + def __init__(self, node: T, key: Any) -> None: self.node = node self.key = key - def __lt__(self, other) -> bool: + def __lt__(self, other: "HeapEntry") -> bool: return self.key < other.key -- GitLab