From 07f72ce93bdf6e4a3197943fc59148e602ba704d Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 24 Jun 2024 16:55:11 +0300
Subject: [PATCH] persistent_dict: fix typing annotations

---
 pytools/persistent_dict.py | 75 ++++++++++++++++++++++++++------------
 1 file changed, 51 insertions(+), 24 deletions(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 56dac06..89ef3a8 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -39,7 +39,8 @@ import sys
 from dataclasses import fields as dc_fields, is_dataclass
 from enum import Enum
 from typing import (
-    TYPE_CHECKING, Any, Generator, Mapping, Optional, Protocol, Tuple, TypeVar)
+    TYPE_CHECKING, Any, Callable, FrozenSet, Iterator, Mapping, Optional, Protocol,
+    Tuple, TypeVar, cast)
 
 
 if TYPE_CHECKING:
@@ -149,7 +150,7 @@ class KeyBuilder:
 
     # this exists so that we can (conceivably) switch algorithms at some point
     # down the road
-    new_hash = hashlib.sha256
+    new_hash: Callable[..., Hash] = hashlib.sha256
 
     def rec(self, key_hash: Hash, key: Any) -> Hash:
         """
@@ -281,11 +282,11 @@ class KeyBuilder:
     def update_for_bytes(key_hash: Hash, key: bytes) -> None:
         key_hash.update(key)
 
-    def update_for_tuple(self, key_hash: Hash, key: tuple) -> None:
+    def update_for_tuple(self, key_hash: Hash, key: Tuple[Any, ...]) -> None:
         for obj_i in key:
             self.rec(key_hash, obj_i)
 
-    def update_for_frozenset(self, key_hash: Hash, key: frozenset) -> None:
+    def update_for_frozenset(self, key_hash: Hash, key: FrozenSet[Any]) -> None:
         from pytools import unordered_hash
 
         unordered_hash(
@@ -335,7 +336,7 @@ class KeyBuilder:
             self.rec(key_hash, fld.name)
             self.rec(key_hash, getattr(key, fld.name, None))
 
-    def update_for_frozendict(self, key_hash: Hash, key: Mapping) -> None:
+    def update_for_frozendict(self, key_hash: Hash, key: Mapping[Any, Any]) -> None:
         from pytools import unordered_hash
 
         unordered_hash(
@@ -423,12 +424,14 @@ def __getattr__(name: str) -> Any:
     raise AttributeError(name)
 
 
+T = TypeVar("T")
 K = TypeVar("K")
 V = TypeVar("V")
 
 
 class _PersistentDictBase(Mapping[K, V]):
-    def __init__(self, identifier: str,
+    def __init__(self,
+                 identifier: str,
                  key_builder: Optional[KeyBuilder] = None,
                  container_dir: Optional[str] = None,
                  enable_wal: bool = False,
@@ -523,9 +526,17 @@ class _PersistentDictBase(Mapping[K, V]):
             raise NoSuchEntryCollisionError(key)
 
     def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
-        return self._exec_sql_fn(lambda: self.conn.execute(*args))
+        def execute() -> sqlite3.Cursor:
+            assert self.conn is not None
+            return self.conn.execute(*args)
+
+        cursor = self._exec_sql_fn(execute)
+        if not isinstance(cursor, sqlite3.Cursor):
+            raise RuntimeError("Failed to execute SQL statement")
+
+        return cursor
 
-    def _exec_sql_fn(self, fn: Any) -> Any:
+    def _exec_sql_fn(self, fn: Callable[[], T]) -> Optional[T]:
         n = 0
 
         while True:
@@ -570,31 +581,36 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def __len__(self) -> int:
         """Return the number of entries in the dictionary."""
-        return next(self._exec_sql("SELECT COUNT(*) FROM dict"))[0]
+        result, = next(self._exec_sql("SELECT COUNT(*) FROM dict"))
+        assert isinstance(result, int)
+        return result
 
-    def __iter__(self) -> Generator[K, None, None]:
+    def __iter__(self) -> Iterator[K]:
         """Return an iterator over the keys in the dictionary."""
         return self.keys()
 
-    def keys(self) -> Generator[K, None, None]:
+    def keys(self) -> Iterator[K]:  # type: ignore[override]
         """Return an iterator over the keys in the dictionary."""
         for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[0]
 
-    def values(self) -> Generator[V, None, None]:
+    def values(self) -> Iterator[V]:  # type: ignore[override]
         """Return an iterator over the values in the dictionary."""
         for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[1]
 
-    def items(self) -> Generator[tuple[K, V], None, None]:
+    def items(self) -> Iterator[Tuple[K, V]]:  # type: ignore[override]
         """Return an iterator over the items in the dictionary."""
         for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])
 
     def nbytes(self) -> int:
         """Return the size of the dictionary in bytes."""
-        return next(self._exec_sql("SELECT page_size * page_count FROM "
-                          "pragma_page_size(), pragma_page_count()"))[0]
+        result, = next(self._exec_sql("SELECT page_size * page_count FROM "
+                                      "pragma_page_size(), pragma_page_count()"))
+        assert isinstance(result, int)
+
+        return result
 
     def __repr__(self) -> str:
         """Return a string representation of the dictionary."""
@@ -648,10 +664,15 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
         :arg in_mem_cache_size: retain an in-memory cache of up to
             *in_mem_cache_size* items (with an LRU replacement policy)
         """
-        _PersistentDictBase.__init__(self, identifier, key_builder,
-                                     container_dir, enable_wal, safe_sync)
+        super().__init__(identifier,
+                         key_builder=key_builder,
+                         container_dir=container_dir,
+                         enable_wal=enable_wal,
+                         safe_sync=safe_sync)
+
         from functools import lru_cache
-        self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch)
+
+        self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch_uncached)
 
     def clear_in_mem_cache(self) -> None:
         """
@@ -682,14 +703,16 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
                     raise ReadOnlyEntryError("WriteOncePersistentDict, "
                                              "tried overwriting key")
 
-    def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
+    def _fetch_uncached(self, keyhash: str) -> Tuple[K, V]:
         # This method is separate from fetch() to allow for LRU caching
         c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                               (keyhash,))
         row = c.fetchone()
         if row is None:
             raise KeyError
-        return pickle.loads(row[0])
+
+        key, value = pickle.loads(row[0])
+        return key, value
 
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
@@ -703,7 +726,7 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
             return value
 
     def clear(self) -> None:
-        _PersistentDictBase.clear(self)
+        super().clear()
         self._fetch.cache_clear()
 
 
@@ -745,8 +768,11 @@ class PersistentDict(_PersistentDictBase[K, V]):
             is faster than the default rollback journal mode, but it is
             not compatible with network filesystems.
         """
-        _PersistentDictBase.__init__(self, identifier, key_builder,
-                                     container_dir, enable_wal, safe_sync)
+        super().__init__(identifier,
+                         key_builder=key_builder,
+                         container_dir=container_dir,
+                         enable_wal=enable_wal,
+                         safe_sync=safe_sync)
 
     def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
         keyhash = self.key_builder(key)
@@ -768,13 +794,14 @@ class PersistentDict(_PersistentDictBase[K, V]):
 
         stored_key, value = pickle.loads(row[0])
         self._collision_check(key, stored_key)
-        return value
+        return cast(V, value)
 
     def remove(self, key: K) -> None:
         """Remove the entry associated with *key* from the dictionary."""
         keyhash = self.key_builder(key)
 
         def remove_inner() -> None:
+            assert self.conn is not None
             self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
             try:
                 # This is split into SELECT/DELETE to allow for a collision check
-- 
GitLab