diff --git a/pytools/graph.py b/pytools/graph.py index dac3c25a89dca2a3585af5bd64f8850668485643..fa2a12f8bf0b8f43fd916be905a84217b99a9d26 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -81,10 +81,10 @@ def reverse_graph(graph: Mapping[T, Collection[T]]) -> Dict[T, FrozenSet[T]]: # {{{ a_star def a_star( # pylint: disable=too-many-locals - initial_state, goal_state, neighbor_map, - estimate_remaining_cost=None, - get_step_cost=lambda x, y: 1 - ): + initial_state: T, goal_state: T, neighbor_map: Mapping[T, Collection[T]], + estimate_remaining_cost: Optional[Callable[[T], float]] = None, + get_step_cost: Callable[[Any, T], float] = lambda x, y: 1 + ) -> List[T]: """ With the default cost and heuristic, this amounts to Dijkstra's algorithm. """ @@ -92,7 +92,8 @@ def a_star( # pylint: disable=too-many-locals from heapq import heappop, heappush if estimate_remaining_cost is None: - def estimate_remaining_cost(x): # pylint: disable=function-redefined + # pylint: disable=function-redefined + def estimate_remaining_cost(x: T) -> float: if x != goal_state: return 1 else: @@ -101,7 +102,7 @@ def a_star( # pylint: disable=too-many-locals class AStarNode: __slots__ = ["state", "parent", "path_cost"] - def __init__(self, state, parent, path_cost): + def __init__(self, state: T, parent: Any, path_cost: float) -> None: self.state = state self.parent = parent self.path_cost = path_cost @@ -119,7 +120,7 @@ def a_star( # pylint: disable=too-many-locals if top.state == goal_state: result = [] - it = top + it: Optional[AStarNode] = top while it is not None: result.append(it.state) it = it.parent @@ -213,7 +214,7 @@ class CycleError(Exception): :attr node: Node in a directed graph that is part of a cycle. """ - def __init__(self, node) -> None: + def __init__(self, node: T) -> None: self.node = node @@ -225,11 +226,11 @@ class HeapEntry: Only needs to define :func:`pytools.graph.__lt__` according to <https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>. """ - def __init__(self, node, key) -> None: + def __init__(self, node: T, key: Any) -> None: self.node = node self.key = key - def __lt__(self, other) -> bool: + def __lt__(self, other: "HeapEntry") -> bool: return self.key < other.key