From 07f72ce93bdf6e4a3197943fc59148e602ba704d Mon Sep 17 00:00:00 2001
From: Alexandru Fikl <alexfikl@gmail.com>
Date: Mon, 24 Jun 2024 16:55:11 +0300
Subject: [PATCH] persistent_dict: fix typing annotations

---
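Notes for reviewers (below the fold, ignored by git am):

Some context on the _exec_sql_fn() change in the diff below: typing the
callback as Callable[[], T] with the wrapper returning Optional[T] lets mypy
carry the callback's return type through the retry loop, which is why
_exec_sql() then narrows the result with an isinstance() check. A minimal
sketch of the same pattern, with hypothetical names (run_with_retries is not
part of pytools):

    import sqlite3
    from typing import Callable, Optional, TypeVar

    T = TypeVar("T")

    def run_with_retries(fn: Callable[[], T],
                         max_attempts: int = 3) -> Optional[T]:
        # Retry *fn* while sqlite reports a transiently locked database;
        # the callback's return type is preserved as T.
        for attempt in range(max_attempts):
            try:
                return fn()
            except sqlite3.OperationalError as exc:
                if "locked" not in str(exc) or attempt == max_attempts - 1:
                    raise
        return None

    # Call sites get Optional[sqlite3.Cursor] and must narrow it:
    # cursor = run_with_retries(lambda: conn.execute("SELECT 1"))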
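The WriteOncePersistentDict changes rename the uncached fetch to
_fetch_uncached and build the LRU wrapper in __init__, so each instance gets
its own cache and the old method-hidden pylint suppression can go away. A
sketch of that idiom with a hypothetical class (not from pytools):

    from functools import lru_cache
    from typing import Callable, Tuple

    class Fetcher:
        def __init__(self, cache_size: int = 256) -> None:
            # Wrap the *renamed* bound method so the cache is per instance.
            self._fetch: Callable[[str], Tuple[str, int]] = lru_cache(
                maxsize=cache_size)(self._fetch_uncached)

        def _fetch_uncached(self, keyhash: str) -> Tuple[str, int]:
            print("miss:", keyhash)  # runs only on a cache miss
            return keyhash, len(keyhash)

    f = Fetcher()
    f._fetch("abc")  # miss: computes and caches
    f._fetch("abc")  # hit: served from the cache, no print

One trade-off of this pattern: the wrapper stored on the instance references
the bound method, creating a reference cycle that only the cyclic garbage
collector can reclaim.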
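Two smaller recurring changes: generator methods now return Iterator[...]
rather than Generator[..., None, None] (equivalent when nothing is sent to
or returned from the generator, and shorter), and the Any coming out of
pickle.loads() is pinned down either with assert isinstance(...) where a
runtime check is possible, or with typing.cast(V, ...) where it is not
(V is a type variable, so isinstance() cannot check it, and cast() does
nothing at runtime). An illustration with hypothetical functions:

    import pickle
    from typing import Iterator

    def squares(n: int) -> Iterator[int]:
        # Same meaning as Generator[int, None, None] for a generator
        # that neither receives sent values nor returns one.
        for i in range(n):
            yield i * i

    def load_int(raw: bytes) -> int:
        value = pickle.loads(raw)      # typed as returning Any
        assert isinstance(value, int)  # runtime check that narrows Any
        return value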
 pytools/persistent_dict.py | 75 ++++++++++++++++++++++++++------------
 1 file changed, 51 insertions(+), 24 deletions(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 56dac06..89ef3a8 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -39,7 +39,8 @@ import sys
 from dataclasses import fields as dc_fields, is_dataclass
 from enum import Enum
 from typing import (
-    TYPE_CHECKING, Any, Generator, Mapping, Optional, Protocol, Tuple, TypeVar)
+    TYPE_CHECKING, Any, Callable, FrozenSet, Iterator, Mapping, Optional, Protocol,
+    Tuple, TypeVar, cast)


 if TYPE_CHECKING:
@@ -149,7 +150,7 @@ class KeyBuilder:

     # this exists so that we can (conceivably) switch algorithms at some point
     # down the road
-    new_hash = hashlib.sha256
+    new_hash: Callable[..., Hash] = hashlib.sha256

     def rec(self, key_hash: Hash, key: Any) -> Hash:
         """
@@ -281,11 +282,11 @@ class KeyBuilder:
     def update_for_bytes(key_hash: Hash, key: bytes) -> None:
         key_hash.update(key)

-    def update_for_tuple(self, key_hash: Hash, key: tuple) -> None:
+    def update_for_tuple(self, key_hash: Hash, key: Tuple[Any, ...]) -> None:
         for obj_i in key:
             self.rec(key_hash, obj_i)

-    def update_for_frozenset(self, key_hash: Hash, key: frozenset) -> None:
+    def update_for_frozenset(self, key_hash: Hash, key: FrozenSet[Any]) -> None:
         from pytools import unordered_hash

         unordered_hash(
@@ -335,7 +336,7 @@ class KeyBuilder:
             self.rec(key_hash, fld.name)
             self.rec(key_hash, getattr(key, fld.name, None))

-    def update_for_frozendict(self, key_hash: Hash, key: Mapping) -> None:
+    def update_for_frozendict(self, key_hash: Hash, key: Mapping[Any, Any]) -> None:
         from pytools import unordered_hash

         unordered_hash(
@@ -423,12 +424,14 @@ def __getattr__(name: str) -> Any:
     raise AttributeError(name)


+T = TypeVar("T")
 K = TypeVar("K")
 V = TypeVar("V")


 class _PersistentDictBase(Mapping[K, V]):
-    def __init__(self, identifier: str,
+    def __init__(self,
+                 identifier: str,
                  key_builder: Optional[KeyBuilder] = None,
                  container_dir: Optional[str] = None,
                  enable_wal: bool = False,
@@ -523,9 +526,17 @@ class _PersistentDictBase(Mapping[K, V]):
             raise NoSuchEntryCollisionError(key)

     def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
-        return self._exec_sql_fn(lambda: self.conn.execute(*args))
+        def execute() -> sqlite3.Cursor:
+            assert self.conn is not None
+            return self.conn.execute(*args)
+
+        cursor = self._exec_sql_fn(execute)
+        if not isinstance(cursor, sqlite3.Cursor):
+            raise RuntimeError("Failed to execute SQL statement")
+
+        return cursor

-    def _exec_sql_fn(self, fn: Any) -> Any:
+    def _exec_sql_fn(self, fn: Callable[[], T]) -> Optional[T]:
         n = 0

         while True:
@@ -570,31 +581,36 @@ class _PersistentDictBase(Mapping[K, V]):

     def __len__(self) -> int:
         """Return the number of entries in the dictionary."""
-        return next(self._exec_sql("SELECT COUNT(*) FROM dict"))[0]
+        result, = next(self._exec_sql("SELECT COUNT(*) FROM dict"))
+        assert isinstance(result, int)
+        return result

-    def __iter__(self) -> Generator[K, None, None]:
+    def __iter__(self) -> Iterator[K]:
         """Return an iterator over the keys in the dictionary."""
         return self.keys()

-    def keys(self) -> Generator[K, None, None]:
+    def keys(self) -> Iterator[K]:  # type: ignore[override]
         """Return an iterator over the keys in the dictionary."""
         for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[0]

-    def values(self) -> Generator[V, None, None]:
+    def values(self) -> Iterator[V]:  # type: ignore[override]
         """Return an iterator over the values in the dictionary."""
         for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[1]

-    def items(self) -> Generator[tuple[K, V], None, None]:
+    def items(self) -> Iterator[Tuple[K, V]]:  # type: ignore[override]
         """Return an iterator over the items in the dictionary."""
         for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])

     def nbytes(self) -> int:
         """Return the size of the dictionary in bytes."""
-        return next(self._exec_sql("SELECT page_size * page_count FROM "
-                                   "pragma_page_size(), pragma_page_count()"))[0]
+        result, = next(self._exec_sql("SELECT page_size * page_count FROM "
+                                      "pragma_page_size(), pragma_page_count()"))
+        assert isinstance(result, int)
+
+        return result

     def __repr__(self) -> str:
         """Return a string representation of the dictionary."""
@@ -648,10 +664,15 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
         :arg in_mem_cache_size: retain an in-memory cache of up to
             *in_mem_cache_size* items (with an LRU replacement policy)
         """
-        _PersistentDictBase.__init__(self, identifier, key_builder,
-                                     container_dir, enable_wal, safe_sync)
+        super().__init__(identifier,
+                         key_builder=key_builder,
+                         container_dir=container_dir,
+                         enable_wal=enable_wal,
+                         safe_sync=safe_sync)
+
         from functools import lru_cache
-        self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch)
+
+        self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch_uncached)

     def clear_in_mem_cache(self) -> None:
         """
@@ -682,14 +703,16 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
             raise ReadOnlyEntryError("WriteOncePersistentDict, "
                                      "tried overwriting key")

-    def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
+    def _fetch_uncached(self, keyhash: str) -> Tuple[K, V]:
         # This method is separate from fetch() to allow for LRU caching
         c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                            (keyhash,))
         row = c.fetchone()
         if row is None:
             raise KeyError
-        return pickle.loads(row[0])
+
+        key, value = pickle.loads(row[0])
+        return key, value

     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
@@ -703,7 +726,7 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
             return value

     def clear(self) -> None:
-        _PersistentDictBase.clear(self)
+        super().clear()
         self._fetch.cache_clear()


@@ -745,8 +768,11 @@ class PersistentDict(_PersistentDictBase[K, V]):
            is faster than the default rollback journal mode, but it is
           not compatible with network filesystems.
""" - _PersistentDictBase.__init__(self, identifier, key_builder, - container_dir, enable_wal, safe_sync) + super().__init__(identifier, + key_builder=key_builder, + container_dir=container_dir, + enable_wal=enable_wal, + safe_sync=safe_sync) def store(self, key: K, value: V, _skip_if_present: bool = False) -> None: keyhash = self.key_builder(key) @@ -768,13 +794,14 @@ class PersistentDict(_PersistentDictBase[K, V]): stored_key, value = pickle.loads(row[0]) self._collision_check(key, stored_key) - return value + return cast(V, value) def remove(self, key: K) -> None: """Remove the entry associated with *key* from the dictionary.""" keyhash = self.key_builder(key) def remove_inner() -> None: + assert self.conn is not None self.conn.execute("BEGIN EXCLUSIVE TRANSACTION") try: # This is split into SELECT/DELETE to allow for a collision check -- GitLab