diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 60131d8cfc52444b1f0175e0c062172d31a09eef..3e6ca3ef31d044d7e0b0f6f78f5ee2ed10b3245d 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -29,15 +29,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 DEALINGS IN THE SOFTWARE.
 """
 
-import errno
+
 import hashlib
 import logging
 import os
-import shutil
+import pickle
+import sqlite3
 import sys
 from dataclasses import fields as dc_fields, is_dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Protocol, TypeVar
+from typing import (
+    TYPE_CHECKING, Any, Generator, Mapping, Optional, Protocol, Tuple, TypeVar)
 
 
 if TYPE_CHECKING:
@@ -64,8 +66,6 @@ valid across interpreter invocations, unlike Python's built-in hashes.
 
 This module also provides a disk-backed dictionary that uses persistent hashing.
 
 .. autoexception:: NoSuchEntryError
-.. autoexception:: NoSuchEntryInvalidKeyError
-.. autoexception:: NoSuchEntryInvalidContentsError
 .. autoexception:: NoSuchEntryCollisionError
 .. autoexception:: ReadOnlyEntryError
 
@@ -90,108 +90,6 @@ Internal stuff that is only here because the documentation tool wants it
 """
 
 
-# {{{ cleanup managers
-
-class CleanupBase:
-    pass
-
-
-class CleanupManager(CleanupBase):
-    def __init__(self):
-        self.cleanups = []
-
-    def register(self, c):
-        self.cleanups.insert(0, c)
-
-    def clean_up(self):
-        for c in self.cleanups:
-            c.clean_up()
-
-    def error_clean_up(self):
-        for c in self.cleanups:
-            c.error_clean_up()
-
-
-class LockManager(CleanupBase):
-    def __init__(self, cleanup_m, lock_file, stacklevel=0):
-        self.lock_file = lock_file
-
-        attempts = 0
-        while True:
-            try:
-                self.fd = os.open(self.lock_file,
-                        os.O_CREAT | os.O_WRONLY | os.O_EXCL)
-                break
-            except OSError:
-                pass
-
-            # This value was chosen based on the py-filelock package:
-            # https://github.com/tox-dev/py-filelock/blob/a6c8fabc4192fa7a4ae19b1875ee842ec5eb4f61/src/filelock/_api.py#L113
-            wait_time_seconds = 0.05
-
-            # Warn every 10 seconds if not able to acquire lock
-            warn_attempts = int(10/wait_time_seconds)
-
-            # Exit after 60 seconds if not able to acquire lock
-            exit_attempts = int(60/wait_time_seconds)
-
-            from time import sleep
-            sleep(wait_time_seconds)
-
-            attempts += 1
-
-            if attempts % warn_attempts == 0:
-                from warnings import warn
-                warn("could not obtain lock -- "
-                        f"delete '{self.lock_file}' if necessary",
-                        stacklevel=1 + stacklevel)
-
-            if attempts > exit_attempts:
-                raise RuntimeError("waited more than one minute "
-                        f"on the lock file '{self.lock_file}' "
-                        "-- something is wrong")
-
-        cleanup_m.register(self)
-
-    def clean_up(self):
-        os.close(self.fd)
-        os.unlink(self.lock_file)
-
-    def error_clean_up(self):
-        pass
-
-
-class ItemDirManager(CleanupBase):
-    def __init__(self, cleanup_m, path, delete_on_error):
-        from os.path import isdir
-
-        self.existed = isdir(path)
-        self.path = path
-        self.delete_on_error = delete_on_error
-
-        cleanup_m.register(self)
-
-    def reset(self):
-        try:
-            shutil.rmtree(self.path)
-        except OSError as e:
-            if e.errno != errno.ENOENT:
-                raise
-
-    def mkdir(self):
-        from os import makedirs
-        makedirs(self.path, exist_ok=True)
-
-    def clean_up(self):
-        pass
-
-    def error_clean_up(self):
-        if self.delete_on_error:
-            self.reset()
-
-# }}}
-
-
 # {{{ key generation
 
 class Hash(Protocol):
@@ -498,18 +396,6 @@ class NoSuchEntryError(KeyError):
     pass
 
 
-class NoSuchEntryInvalidKeyError(NoSuchEntryError):
-    """Raised when an entry is not found in a :class:`PersistentDict` due to an
-    invalid key file."""
-    pass
-
-
-class NoSuchEntryInvalidContentsError(NoSuchEntryError):
-    """Raised when an entry is not found in a :class:`PersistentDict` due to an
-    invalid contents file."""
-    pass
-
-
 class NoSuchEntryCollisionError(NoSuchEntryError):
     """Raised when an entry is not found in a :class:`PersistentDict`, but it
     contains an entry with the same hash key (hash collision)."""
@@ -527,15 +413,27 @@ class CollisionWarning(UserWarning):
     pass
 
 
+def __getattr__(name: str) -> Any:
+    if name in ("NoSuchEntryInvalidKeyError",
+                "NoSuchEntryInvalidContentsError"):
+        from warnings import warn
+        warn(f"pytools.persistent_dict.{name} has been removed.",
+             DeprecationWarning, stacklevel=2)
+        return NoSuchEntryError
+
+    raise AttributeError(name)
+
+
 K = TypeVar("K")
 V = TypeVar("V")
 
 
-class _PersistentDictBase(Generic[K, V]):
+class _PersistentDictBase(Mapping[K, V]):
     def __init__(self, identifier: str,
                  key_builder: Optional[KeyBuilder] = None,
-                 container_dir: Optional[str] = None) -> None:
+                 container_dir: Optional[str] = None,
+                 enable_wal: bool = False) -> None:
         self.identifier = identifier
+        self.conn = None
 
         if key_builder is None:
             key_builder = KeyBuilder()
@@ -549,112 +447,126 @@ class _PersistentDictBase(Generic[K, V]):
         if sys.platform == "darwin" and os.getenv("XDG_CACHE_HOME") is not None:
             # platformdirs does not handle XDG_CACHE_HOME on macOS
             # https://github.com/platformdirs/platformdirs/issues/269
-            cache_dir = join(os.getenv("XDG_CACHE_HOME"), "pytools")
+            container_dir = join(os.getenv("XDG_CACHE_HOME"), "pytools")
         else:
-            cache_dir = platformdirs.user_cache_dir("pytools", "pytools")
+            container_dir = platformdirs.user_cache_dir("pytools", "pytools")
 
-        container_dir = join(
-            cache_dir,
-            "pdict-v4-{}-py{}".format(
-                identifier,
-                ".".join(str(i) for i in sys.version_info)))
+        self.filename = join(container_dir, f"pdict-v5-{identifier}"
+                             + ".".join(str(i) for i in sys.version_info)
+                             + ".sqlite")
 
         self.container_dir = container_dir
 
         self._make_container_dir()
 
-    @staticmethod
-    def _warn(msg: str, category: Any = UserWarning, stacklevel: int = 0) -> None:
-        from warnings import warn
-        warn(msg, category, stacklevel=1 + stacklevel)
+        # isolation_level=None: enable autocommit mode
+        # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
+        self.conn = sqlite3.connect(self.filename, isolation_level=None)
 
-    def store_if_not_present(self, key: K, value: V,
-                             _stacklevel: int = 0) -> None:
-        """Store (*key*, *value*) if *key* is not already present."""
-        self.store(key, value, _skip_if_present=True, _stacklevel=1 + _stacklevel)
+        self.conn.execute(
+            "CREATE TABLE IF NOT EXISTS dict "
+            "(keyhash TEXT NOT NULL PRIMARY KEY, key_value TEXT NOT NULL)"
+            )
 
-    def store(self, key: K, value: V, _skip_if_present: bool = False,
-              _stacklevel: int = 0) -> None:
-        """Store (*key*, *value*) in the dictionary."""
-        raise NotImplementedError()
+        # https://www.sqlite.org/wal.html
+        if enable_wal:
+            self.conn.execute("PRAGMA journal_mode = 'WAL'")
 
-    def fetch(self, key: K, _stacklevel: int = 0) -> V:
-        """Return the value associated with *key* in the dictionary."""
-        raise NotImplementedError()
+        # Note: the following configuration values were taken from litedict:
+        # https://github.com/litements/litedict/blob/377603fa597453ffd9997186a493ed4fd23e5399/litedict.py#L67-L70
+        # They result in fast operations while maintaining database integrity
+        # even in the face of concurrent accesses and power loss.
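+        # (A note on the durability trade-off, per the SQLite documentation
+        # linked above: synchronous=NORMAL trades durability of the most
+        # recent commits for speed. In WAL mode this cannot corrupt the
+        # database, though a transaction committed just before a power loss
+        # may be rolled back on recovery.)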
-    @staticmethod
-    def _read(path: str) -> V:
-        from pickle import load
-        with open(path, "rb") as inf:
-            return load(inf)
-
-    @staticmethod
-    def _write(path: str, value: V) -> None:
-        from pickle import HIGHEST_PROTOCOL, dump
-        with open(path, "wb") as outf:
-            dump(value, outf, protocol=HIGHEST_PROTOCOL)
-
-    def _item_dir(self, hexdigest_key: str) -> str:
-        from os.path import join
-
-        # Some file systems limit the number of directories in a directory.
-        # For ext4, that limit appears to be 64K for example.
-        # This doesn't solve that problem, but it makes it much less likely
-
-        return join(self.container_dir,
-                    hexdigest_key[:3],
-                    hexdigest_key[3:6],
-                    hexdigest_key[6:])
-
-    def _key_file(self, hexdigest_key: str) -> str:
-        from os.path import join
-        return join(self._item_dir(hexdigest_key), "key")
+        # temp_store=2: use in-memory temp store
+        # https://www.sqlite.org/pragma.html#pragma_temp_store
+        self.conn.execute("PRAGMA temp_store = 2")
 
-    def _contents_file(self, hexdigest_key: str) -> str:
-        from os.path import join
-        return join(self._item_dir(hexdigest_key), "contents")
+        # https://www.sqlite.org/pragma.html#pragma_synchronous
+        self.conn.execute("PRAGMA synchronous = NORMAL")
 
-    def _lock_file(self, hexdigest_key: str) -> str:
-        from os.path import join
-        return join(self.container_dir, str(hexdigest_key) + ".lock")
+        # 64 MByte of cache
+        # https://www.sqlite.org/pragma.html#pragma_cache_size
+        self.conn.execute("PRAGMA cache_size = -64000")
 
-    def _make_container_dir(self) -> None:
-        """Create the container directory to store the dictionary."""
-        os.makedirs(self.container_dir, exist_ok=True)
+    def __del__(self) -> None:
+        if self.conn:
+            self.conn.close()
 
-    def _collision_check(self, key: K, stored_key: K, _stacklevel: int) -> None:
+    def _collision_check(self, key: K, stored_key: K) -> None:
         if stored_key != key:
             # Key collision, oh well.
-            self._warn(f"{self.identifier}: key collision in cache at "
+            from warnings import warn
+            warn(f"{self.identifier}: key collision in cache at "
                     f"'{self.container_dir}' -- these are sufficiently unlikely "
                     "that they're often indicative of a broken hash key "
                     "implementation (that is not considering some elements "
                     "relevant for equality comparison)",
-                    CollisionWarning,
-                    1 + _stacklevel)
+                    CollisionWarning
+                 )
 
             # This is here so we can step through equality comparison to
             # see what is actually non-equal.
             stored_key == key  # pylint:disable=pointless-statement  # noqa: B015
             raise NoSuchEntryCollisionError(key)
 
+    def store_if_not_present(self, key: K, value: V) -> None:
+        """Store (*key*, *value*) if *key* is not already present."""
+        self.store(key, value, _skip_if_present=True)
+
+    def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
+        """Store (*key*, *value*) in the dictionary."""
+        raise NotImplementedError()
+
+    def fetch(self, key: K) -> V:
+        """Return the value associated with *key* in the dictionary."""
+        raise NotImplementedError()
+
+    def _make_container_dir(self) -> None:
+        """Create the container directory to store the dictionary."""
+        os.makedirs(self.container_dir, exist_ok=True)
+
     def __getitem__(self, key: K) -> V:
         """Return the value associated with *key* in the dictionary."""
-        return self.fetch(key, _stacklevel=1)
+        return self.fetch(key)
 
     def __setitem__(self, key: K, value: V) -> None:
         """Store (*key*, *value*) in the dictionary."""
-        self.store(key, value, _stacklevel=1)
+        self.store(key, value)
+
+    def __len__(self) -> int:
+        """Return the number of entries in the dictionary."""
+        return next(self.conn.execute("SELECT COUNT(*) FROM dict"))[0]
+
+    def __iter__(self) -> Generator[K, None, None]:
+        """Return an iterator over the keys in the dictionary."""
+        return self.keys()
+
+    def keys(self) -> Generator[K, None, None]:
+        """Return an iterator over the keys in the dictionary."""
+        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+            yield pickle.loads(row[0])[0]
+
+    def values(self) -> Generator[V, None, None]:
+        """Return an iterator over the values in the dictionary."""
+        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+            yield pickle.loads(row[0])[1]
+
+    def items(self) -> Generator[Tuple[K, V], None, None]:
+        """Return an iterator over the items in the dictionary."""
+        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+            yield pickle.loads(row[0])
+
+    def nbytes(self) -> int:
+        """Return the size of the dictionary in bytes."""
+        return next(self.conn.execute("SELECT page_size * page_count FROM "
+                                      "pragma_page_size(), pragma_page_count()"))[0]
+
+    def __repr__(self) -> str:
+        """Return a string representation of the dictionary."""
+        return f"{type(self).__name__}({self.filename}, nitems={len(self)})"
 
     def clear(self) -> None:
         """Remove all entries from the dictionary."""
-        try:
-            shutil.rmtree(self.container_dir)
-        except OSError as e:
-            if e.errno != errno.ENOENT:
-                raise
-
-        self._make_container_dir()
+        self.conn.execute("DELETE FROM dict")
 
 
 class WriteOncePersistentDict(_PersistentDictBase[K, V]):
@@ -664,6 +576,13 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
     Compared with :class:`PersistentDict`, this class has faster
     retrieval times because it uses an LRU cache to cache entries in memory.
 
+    .. note::
+
+        This class intentionally does not store all values with a certain
+        key, based on the assumption that key conflicts are highly unlikely,
+        and that, if they occur, they are almost always due to a bug in the
+        hash key generation code (:class:`KeyBuilder`).
+
     .. automethod:: __init__
     .. automethod:: __getitem__
    .. automethod:: __setitem__
@@ -676,19 +595,23 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
     def __init__(self, identifier: str,
                  key_builder: Optional[KeyBuilder] = None,
                  container_dir: Optional[str] = None,
+                 enable_wal: bool = False,
                  in_mem_cache_size: int = 256) -> None:
         """
-        :arg identifier: a file-name-compatible string identifying this
+        :arg identifier: a filename-compatible string identifying this
             dictionary
         :arg key_builder: a subclass of :class:`KeyBuilder`
         :arg container_dir: the directory in which to store this
             dictionary. If ``None``, the default cache directory from
             :func:`platformdirs.user_cache_dir` is used
+        :arg enable_wal: enable write-ahead logging (WAL) mode. This mode
+            is faster than the default rollback journal mode, but it is
+            not compatible with network filesystems.
         :arg in_mem_cache_size: retain an in-memory cache of up to
             *in_mem_cache_size* items (with an LRU replacement policy)
         """
-        _PersistentDictBase.__init__(self, identifier, key_builder, container_dir)
-        self._in_mem_cache_size = in_mem_cache_size
+        _PersistentDictBase.__init__(self, identifier, key_builder,
+                                     container_dir, enable_wal)
 
         from functools import lru_cache
         self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch)
@@ -698,129 +621,38 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
 
         .. versionadded:: 2023.1.1
         """
         self._fetch.cache_clear()
 
-    def _spin_until_removed(self, lock_file: str, stacklevel: int) -> None:
-        from os.path import exists
-
-        attempts = 0
-        while exists(lock_file):
-            from time import sleep
-            sleep(1)
-
-            attempts += 1
-
-            if attempts > 10:
-                self._warn(
-                    f"waiting until unlocked--delete '{lock_file}' if necessary",
-                    stacklevel=1 + stacklevel)
-
-            if attempts > 3 * 60:
-                raise RuntimeError("waited more than three minutes "
-                        f"on the lock file '{lock_file}'"
-                        "--something is wrong")
-
-    def store(self, key: K, value: V, _skip_if_present: bool = False,
-              _stacklevel: int = 0) -> None:
-        hexdigest_key = self.key_builder(key)
-
-        cleanup_m = CleanupManager()
-        try:
-            try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key),
-                        1 + _stacklevel)
-                item_dir_m = ItemDirManager(
-                        cleanup_m, self._item_dir(hexdigest_key),
-                        delete_on_error=False)
-
-                if item_dir_m.existed:
-                    if _skip_if_present:
-                        return
-                    raise ReadOnlyEntryError(key)
-
-                item_dir_m.mkdir()
-
-                key_path = self._key_file(hexdigest_key)
-                value_path = self._contents_file(hexdigest_key)
-
-                self._write(value_path, value)
-                self._write(key_path, key)
-
-                logger.debug("%s: disk cache store [key=%s]",
-                        self.identifier, hexdigest_key)
-            except Exception:
-                cleanup_m.error_clean_up()
-                raise
-        finally:
-            cleanup_m.clean_up()
-
-    def fetch(self, key: K, _stacklevel: int = 0) -> Any:
-        hexdigest_key = self.key_builder(key)
-
-        (stored_key, stored_value) = self._fetch(hexdigest_key, 1 + _stacklevel)
-
-        self._collision_check(key, stored_key, 1 + _stacklevel)
-
-        return stored_value
-
-    def _fetch(self, hexdigest_key: str,  # pylint:disable=method-hidden
-            _stacklevel: int = 0) -> V:
-        # This is separate from fetch() to allow for LRU caching
-
-        # {{{ check path exists and is unlocked
-
-        item_dir = self._item_dir(hexdigest_key)
-
-        from os.path import isdir
-        if not isdir(item_dir):
-            logger.debug("%s: disk cache miss [key=%s]",
-                    self.identifier, hexdigest_key)
-            raise NoSuchEntryError(hexdigest_key)
-
-        lock_file = self._lock_file(hexdigest_key)
-        self._spin_until_removed(lock_file, 1 + _stacklevel)
-
-        # }}}
-
-        key_file = self._key_file(hexdigest_key)
-        contents_file = self._contents_file(hexdigest_key)
-
-        # Note: Unlike PersistentDict, this doesn't autodelete invalid entires,
-        # because that would lead to a race condition.
-
-        # {{{ load key file and do equality check
+    def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
+        keyhash = self.key_builder(key)
+        v = pickle.dumps((key, value))
 
         try:
-            read_key = self._read(key_file)
-        except Exception as e:
-            self._warn(f"{type(self).__name__}({self.identifier}) "
-                    f"encountered an invalid key file for key {hexdigest_key}. "
-                    f"Remove the directory '{item_dir}' if necessary. "
-                    f"(caught: {type(e).__name__}: {e})",
-                    stacklevel=1 + _stacklevel)
-            raise NoSuchEntryInvalidKeyError(hexdigest_key)
-
-        # }}}
-
-        logger.debug("%s: disk cache hit [key=%s]",
-                self.identifier, hexdigest_key)
-
-        # {{{ load contents
+            self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
+        except sqlite3.IntegrityError:
+            if not _skip_if_present:
+                raise ReadOnlyEntryError("WriteOncePersistentDict, "
+                                         "tried overwriting key")
+
+    def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
+        # This method is separate from fetch() to allow for LRU caching
+        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                              (keyhash,))
+        row = c.fetchone()
+        if row is None:
+            raise KeyError
+        return pickle.loads(row[0])
+
+    def fetch(self, key: K) -> V:
+        keyhash = self.key_builder(key)
 
         try:
-            read_contents = self._read(contents_file)
-        except Exception as e:
-            self._warn(f"{type(self).__name__}({self.identifier}) "
-                    f"encountered an invalid contents file for key {hexdigest_key}. "
-                    f"Remove the directory '{item_dir}' if necessary."
-                    f"(caught: {type(e).__name__}: {e})",
-                    stacklevel=1 + _stacklevel)
-            raise NoSuchEntryInvalidContentsError(hexdigest_key)
-
-        # }}}
-
-        return (read_key, read_contents)
+            stored_key, value = self._fetch(keyhash)
+        except KeyError:
+            raise NoSuchEntryError(key)
+        else:
+            self._collision_check(key, stored_key)
+            return value
 
     def clear(self) -> None:
         _PersistentDictBase.clear(self)
@@ -830,6 +662,13 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
 class PersistentDict(_PersistentDictBase[K, V]):
     """A concurrent disk-backed dictionary.
 
+    .. note::
+
+        This class intentionally does not store all values with a certain
+        key, based on the assumption that key conflicts are highly unlikely,
+        and that, if they occur, they are almost always due to a bug in the
+        hash key generation code (:class:`KeyBuilder`).
+
     .. automethod:: __init__
     .. automethod:: __getitem__
     .. automethod:: __setitem__
@@ -843,161 +682,72 @@ class PersistentDict(_PersistentDictBase[K, V]):
 
     def __init__(self, identifier: str,
                  key_builder: Optional[KeyBuilder] = None,
-                 container_dir: Optional[str] = None) -> None:
+                 container_dir: Optional[str] = None,
+                 enable_wal: bool = False) -> None:
         """
-        :arg identifier: a file-name-compatible string identifying this
+        :arg identifier: a filename-compatible string identifying this
            dictionary
         :arg key_builder: a subclass of :class:`KeyBuilder`
         :arg container_dir: the directory in which to store this
            dictionary. If ``None``, the default cache directory from
            :func:`platformdirs.user_cache_dir` is used
+        :arg enable_wal: enable write-ahead logging (WAL) mode. This mode
+            is faster than the default rollback journal mode, but it is
+            not compatible with network filesystems.
         """
-        _PersistentDictBase.__init__(self, identifier, key_builder, container_dir)
+        _PersistentDictBase.__init__(self, identifier, key_builder,
+                                     container_dir, enable_wal)
 
-    def store(self, key: K, value: V, _skip_if_present: bool = False,
-              _stacklevel: int = 0) -> None:
-        hexdigest_key = self.key_builder(key)
+    def store(self, key: K, value: V, _skip_if_present: bool = False) -> None:
+        keyhash = self.key_builder(key)
+        v = pickle.dumps((key, value))
 
-        cleanup_m = CleanupManager()
-        try:
-            try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key),
-                        1 + _stacklevel)
-                item_dir_m = ItemDirManager(
-                        cleanup_m, self._item_dir(hexdigest_key),
-                        delete_on_error=True)
-
-                if item_dir_m.existed:
-                    if _skip_if_present:
-                        return
-                    item_dir_m.reset()
-
-                item_dir_m.mkdir()
-
-                key_path = self._key_file(hexdigest_key)
-                value_path = self._contents_file(hexdigest_key)
-
-                self._write(value_path, value)
-                self._write(key_path, key)
-
-                logger.debug("%s: cache store [key=%s]",
-                        self.identifier, hexdigest_key)
-            except Exception:
-                cleanup_m.error_clean_up()
-                raise
-        finally:
-            cleanup_m.clean_up()
-
-    def fetch(self, key: K, _stacklevel: int = 0) -> V:
-        hexdigest_key = self.key_builder(key)
-        item_dir = self._item_dir(hexdigest_key)
-
-        from os.path import isdir
-        if not isdir(item_dir):
-            logger.debug("%s: cache miss [key=%s]",
-                    self.identifier, hexdigest_key)
-            raise NoSuchEntryError(key)
+        if _skip_if_present:
+            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
+                              (keyhash, v))
+        else:
+            self.conn.execute("INSERT OR REPLACE INTO dict VALUES (?, ?)",
+                              (keyhash, v))
 
-        cleanup_m = CleanupManager()
-        try:
-            try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key),
-                        1 + _stacklevel)
-                item_dir_m = ItemDirManager(
-                        cleanup_m, item_dir, delete_on_error=False)
-
-                key_path = self._key_file(hexdigest_key)
-                value_path = self._contents_file(hexdigest_key)
-
-                # {{{ load key
-
-                try:
-                    read_key = self._read(key_path)
-                except Exception as e:
-                    item_dir_m.reset()
-                    self._warn(f"{type(self).__name__}({self.identifier}) "
-                            "encountered an invalid key file for key "
-                            f"{hexdigest_key}. Entry deleted."
-                            f"(caught: {type(e).__name__}: {e})",
-                            stacklevel=1 + _stacklevel)
-                    raise NoSuchEntryInvalidKeyError(key)
-
-                self._collision_check(key, read_key, 1 + _stacklevel)
-
-                # }}}
-
-                logger.debug("%s: cache hit [key=%s]",
-                        self.identifier, hexdigest_key)
-
-                # {{{ load value
-
-                try:
-                    read_contents = self._read(value_path)
-                except Exception as e:
-                    item_dir_m.reset()
-                    self._warn(f"{type(self).__name__}({self.identifier}) "
-                            "encountered an invalid contents file for key "
-                            f"{hexdigest_key}. Entry deleted."
-                            f"(caught: {type(e).__name__}: {e})",
-                            stacklevel=1 + _stacklevel)
-                    raise NoSuchEntryInvalidContentsError(key)
-
-                return read_contents
-
-                # }}}
-
-            except Exception:
-                cleanup_m.error_clean_up()
-                raise
-        finally:
-            cleanup_m.clean_up()
-
-    def remove(self, key: K, _stacklevel: int = 0) -> None:
-        """Remove the entry associated with *key* from the dictionary."""
-        hexdigest_key = self.key_builder(key)
+    def fetch(self, key: K) -> V:
+        keyhash = self.key_builder(key)
 
-        item_dir = self._item_dir(hexdigest_key)
-        from os.path import isdir
-        if not isdir(item_dir):
+        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                              (keyhash,))
+        row = c.fetchone()
+        if row is None:
             raise NoSuchEntryError(key)
 
-        cleanup_m = CleanupManager()
-        try:
-            try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key),
-                        1 + _stacklevel)
-                item_dir_m = ItemDirManager(
-                        cleanup_m, item_dir, delete_on_error=False)
-                key_file = self._key_file(hexdigest_key)
-
-                # {{{ load key
-
-                try:
-                    read_key = self._read(key_file)
-                except Exception as e:
-                    item_dir_m.reset()
-                    self._warn(f"{type(self).__name__}({self.identifier}) "
-                            "encountered an invalid key file for key "
-                            f"{hexdigest_key}. Entry deleted"
-                            f"(caught: {type(e).__name__}: {e})",
-                            stacklevel=1 + _stacklevel)
-                    raise NoSuchEntryInvalidKeyError(key)
+        stored_key, value = pickle.loads(row[0])
+        self._collision_check(key, stored_key)
+        return value
 
-                self._collision_check(key, read_key, 1 + _stacklevel)
-
-                # }}}
+    def remove(self, key: K) -> None:
+        """Remove the entry associated with *key* from the dictionary."""
+        keyhash = self.key_builder(key)
 
-                item_dir_m.reset()
+        self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
 
-            except Exception:
-                cleanup_m.error_clean_up()
-                raise
-        finally:
-            cleanup_m.clean_up()
+        try:
+            # This is split into SELECT/DELETE to allow for a collision check
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            row = c.fetchone()
+            if row is None:
+                raise NoSuchEntryError(key)
+
+            stored_key, _value = pickle.loads(row[0])
+            self._collision_check(key, stored_key)
+
+            self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
+            self.conn.execute("COMMIT")
+        except Exception as e:
+            self.conn.execute("ROLLBACK")
+            raise e
 
     def __delitem__(self, key: K) -> None:
         """Remove the entry associated with *key* from the dictionary."""
-        self.remove(key, _stacklevel=1)
+        self.remove(key)
 
 # }}}
 
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index f9450b6f80e016f3e408fa1f84f96ad45e92181b..5e8dfd41af4d3e000d402f7e9554eca617d12544 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -168,6 +168,9 @@ def test_persistent_dict_deletion() -> None:
     pdict[0] = 0
     del pdict[0]
 
+    with pytest.raises(NoSuchEntryError):
+        pdict.remove(0)
+
     with pytest.raises(NoSuchEntryError):
         pdict.fetch(0)
 
@@ -719,6 +722,95 @@ def test_xdg_cache_home() -> None:
         shutil.rmtree(xdg_dir)
 
 
+def test_speed():
+    import time
+
+    tmpdir = tempfile.mkdtemp()
+    pdict = WriteOncePersistentDict("pytools-test", container_dir=tmpdir)
+
+    start = time.time()
+    for i in range(10000):
+        pdict[i] = i
+    end = time.time()
+    print("persistent dict write time: ", end-start)
+
+    start = time.time()
+    for _ in range(5):
+        for i in range(10000):
+            pdict[i]
+    end = time.time()
+    print("persistent dict read time: ", end-start)
+
+    shutil.rmtree(tmpdir)
+
+
+def test_size():
+    try:
+        tmpdir = tempfile.mkdtemp()
+        pdict = PersistentDict("pytools-test", container_dir=tmpdir)
+
+        for i in range(10000):
+            pdict[f"foobarbazfoobbb{i}"] = i
+
+        size = pdict.nbytes()
+        print("sqlite size: ", size/1024/1024, " MByte")
+        assert 1*1024*1024 < size < 2*1024*1024
+    finally:
+        shutil.rmtree(tmpdir)
+
+
+def test_len():
+    try:
+        tmpdir = tempfile.mkdtemp()
+        pdict = PersistentDict("pytools-test", container_dir=tmpdir)
+
+        assert len(pdict) == 0
+
+        for i in range(10000):
+            pdict[i] = i
+
+        assert len(pdict) == 10000
+
+        pdict.clear()
+
+        assert len(pdict) == 0
+    finally:
+        shutil.rmtree(tmpdir)
+
+
+def test_repr():
+    try:
+        tmpdir = tempfile.mkdtemp()
+        pdict = PersistentDict("pytools-test", container_dir=tmpdir)
+
+        assert repr(pdict)[:15] == "PersistentDict("
+    finally:
+        shutil.rmtree(tmpdir)
+
+
+def test_keys_values_items():
+    try:
+        tmpdir = tempfile.mkdtemp()
+        pdict = PersistentDict("pytools-test", container_dir=tmpdir)
+
+        for i in range(10000):
+            pdict[i] = i
+
+        # This also tests deterministic iteration order
+        assert len(list(pdict.keys())) == 10000 == len(set(pdict.keys()))
+        assert list(pdict.keys()) == list(range(10000))
+        assert list(pdict.values()) == list(range(10000))
+        assert list(pdict.items()) == list(zip(list(pdict.keys()), range(10000)))
+
+        assert ([k for k in pdict.keys()]  # noqa: C416
+                == list(pdict.keys())
+                == list(pdict)
+                == [k for k in pdict])  # noqa: C416
+
+    finally:
+        shutil.rmtree(tmpdir)
+
+
 def global_fun():
     pass
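
Usage sketch (not part of the patch): the snippet below is a minimal round-trip
through the sqlite-backed API introduced above. The identifiers, keys, and
temporary directory are arbitrary illustration values.

    import shutil
    import tempfile

    from pytools.persistent_dict import (
        NoSuchEntryError, PersistentDict, ReadOnlyEntryError,
        WriteOncePersistentDict)

    tmpdir = tempfile.mkdtemp()
    try:
        # Mutable variant: stores use INSERT OR REPLACE, so overwriting is fine.
        pdict = PersistentDict("demo", container_dir=tmpdir, enable_wal=True)
        pdict["answer"] = 42
        pdict["answer"] = 43
        assert pdict["answer"] == 43
        assert len(pdict) == 1

        del pdict["answer"]
        try:
            pdict.fetch("answer")
        except NoSuchEntryError:
            pass  # the entry is gone

        # Write-once variant: a second store of the same key raises.
        wdict = WriteOncePersistentDict("demo-wo", container_dir=tmpdir,
                                        in_mem_cache_size=16)
        wdict.store_if_not_present("answer", 42)
        try:
            wdict["answer"] = 43
        except ReadOnlyEntryError:
            pass  # as documented for WriteOncePersistentDict
    finally:
        shutil.rmtree(tmpdir)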