From 2784aecb8d8361c03a168c322f7f03b6b9e276b6 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Mon, 14 Nov 2022 23:10:47 -0600
Subject: [PATCH] graph: more mypy annotations

---
 pytools/graph.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/pytools/graph.py b/pytools/graph.py
index dac3c25..fa2a12f 100644
--- a/pytools/graph.py
+++ b/pytools/graph.py
@@ -81,10 +81,10 @@ def reverse_graph(graph: Mapping[T, Collection[T]]) -> Dict[T, FrozenSet[T]]:
 # {{{ a_star
 
 def a_star(  # pylint: disable=too-many-locals
-        initial_state, goal_state, neighbor_map,
-        estimate_remaining_cost=None,
-        get_step_cost=lambda x, y: 1
-        ):
+        initial_state: T, goal_state: T, neighbor_map: Mapping[T, Collection[T]],
+        estimate_remaining_cost: Optional[Callable[[T], float]] = None,
+        get_step_cost: Callable[[Any, T], float] = lambda x, y: 1
+        ) -> List[T]:
     """
     With the default cost and heuristic, this amounts to Dijkstra's algorithm.
     """
@@ -92,7 +92,8 @@ def a_star(  # pylint: disable=too-many-locals
     from heapq import heappop, heappush
 
     if estimate_remaining_cost is None:
-        def estimate_remaining_cost(x):  # pylint: disable=function-redefined
+        # pylint: disable=function-redefined
+        def estimate_remaining_cost(x: T) -> float:
             if x != goal_state:
                 return 1
             else:
@@ -101,7 +102,7 @@ def a_star(  # pylint: disable=too-many-locals
     class AStarNode:
         __slots__ = ["state", "parent", "path_cost"]
 
-        def __init__(self, state, parent, path_cost):
+        def __init__(self, state: T, parent: Any, path_cost: float) -> None:
             self.state = state
             self.parent = parent
             self.path_cost = path_cost
@@ -119,7 +120,7 @@ def a_star(  # pylint: disable=too-many-locals
 
         if top.state == goal_state:
             result = []
-            it = top
+            it: Optional[AStarNode] = top
             while it is not None:
                 result.append(it.state)
                 it = it.parent
@@ -213,7 +214,7 @@ class CycleError(Exception):
 
     :attr node: Node in a directed graph that is part of a cycle.
     """
-    def __init__(self, node) -> None:
+    def __init__(self, node: T) -> None:
         self.node = node
 
 
@@ -225,11 +226,11 @@ class HeapEntry:
     Only needs to define :func:`pytools.graph.__lt__` according to
     <https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
     """
-    def __init__(self, node, key) -> None:
+    def __init__(self, node: T, key: Any) -> None:
         self.node = node
         self.key = key
 
-    def __lt__(self, other) -> bool:
+    def __lt__(self, other: "HeapEntry") -> bool:
         return self.key < other.key
 
 
-- 
GitLab