From 93eab0ff8d7d765aea7bb0424ab80b2280ecc8ce Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Fri, 7 Jun 2024 13:59:49 -0600
Subject: [PATCH] PersistentDict: concurrency improvements (#231)

---
Review notes (text between the "---" line and the diff is ignored by git am):

* Line breaks and indentation were restored from a whitespace-collapsed copy
  of this patch. The whitespace of unchanged context lines is a best-effort
  reconstruction and should be checked against the target tree.
* The added code differs from upstream commit 93eab0ff in
  pytools/persistent_dict.py only:
    - _exec_sql_fn: dropped the unreachable "else: break" (the try body
      always returns), replaced "not a == b" with "a != b", passed
      stacklevel to warn(), and added docstrings.
    - remove_inner: re-raise with a bare "raise".
  Added/removed line counts are unchanged, so hunk headers and the diffstat
  below still apply; the post-image blob id on the "index" line does not.

 pytools/persistent_dict.py           | 100 ++++++++++++++++-----------
 pytools/test/test_persistent_dict.py |  54 +++++++++++++++
 setup.py                             |   1 +
 3 files changed, 116 insertions(+), 39 deletions(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 7af3e28..56dac06 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -463,21 +463,21 @@ class _PersistentDictBase(Mapping[K, V]):
         # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
         self.conn = sqlite3.connect(self.filename, isolation_level=None)
 
-        self.conn.execute(
+        self._exec_sql(
             "CREATE TABLE IF NOT EXISTS dict "
             "(keyhash TEXT NOT NULL PRIMARY KEY, key_value TEXT NOT NULL)"
             )
 
         # https://www.sqlite.org/wal.html
         if enable_wal:
-            self.conn.execute("PRAGMA journal_mode = 'WAL'")
+            self._exec_sql("PRAGMA journal_mode = 'WAL'")
 
         # Note: the following configuration values were taken mostly from litedict:
         # https://github.com/litements/litedict/blob/377603fa597453ffd9997186a493ed4fd23e5399/litedict.py#L67-L70
 
         # Use in-memory temp store
         # https://www.sqlite.org/pragma.html#pragma_temp_store
-        self.conn.execute("PRAGMA temp_store = 'MEMORY'")
+        self._exec_sql("PRAGMA temp_store = 'MEMORY'")
 
         # fsync() can be extremely slow on some systems.
         # See https://github.com/inducer/pytools/issues/227 for context.
@@ -493,13 +493,13 @@ class _PersistentDictBase(Mapping[K, V]):
                      "Pass 'safe_sync=False' if occasional data loss is tolerable. "
                      "Pass 'safe_sync=True' to suppress this warning.",
                      stacklevel=3)
-            self.conn.execute("PRAGMA synchronous = 'NORMAL'")
+            self._exec_sql("PRAGMA synchronous = 'NORMAL'")
         else:
-            self.conn.execute("PRAGMA synchronous = 'OFF'")
+            self._exec_sql("PRAGMA synchronous = 'OFF'")
 
         # 64 MByte of cache
         # https://www.sqlite.org/pragma.html#pragma_cache_size
-        self.conn.execute("PRAGMA cache_size = -64000")
+        self._exec_sql("PRAGMA cache_size = -64000")
 
     def __del__(self) -> None:
         if self.conn:
@@ -522,6 +522,28 @@ class _PersistentDictBase(Mapping[K, V]):
             stored_key == key  # pylint:disable=pointless-statement  # noqa: B015
             raise NoSuchEntryCollisionError(key)
 
+    def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
+        """Execute one SQL statement, retrying while the database is busy."""
+        return self._exec_sql_fn(lambda: self.conn.execute(*args))
+
+    def _exec_sql_fn(self, fn: Any) -> Any:
+        """Call *fn* and return its result, retrying while the database is busy."""
+        n = 0
+
+        while True:
+            n += 1
+            try:
+                return fn()
+            except sqlite3.OperationalError as e:
+                # Retry if busy (or if the error code is unavailable, Python < 3.11)
+                if (hasattr(e, "sqlite_errorcode")
+                        and e.sqlite_errorcode != sqlite3.SQLITE_BUSY):
+                    raise
+                if n % 20 == 0:
+                    from warnings import warn
+                    warn(f"PersistentDict: database '{self.filename}' busy, {n} "
+                         "retries", stacklevel=3)
+
     def store_if_not_present(self, key: K, value: V) -> None:
         """Store (*key*, *value*) if *key* is not already present."""
         self.store(key, value, _skip_if_present=True)
@@ -548,7 +570,7 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def __len__(self) -> int:
         """Return the number of entries in the dictionary."""
-        return next(self.conn.execute("SELECT COUNT(*) FROM dict"))[0]
+        return next(self._exec_sql("SELECT COUNT(*) FROM dict"))[0]
 
     def __iter__(self) -> Generator[K, None, None]:
         """Return an iterator over the keys in the dictionary."""
@@ -556,22 +578,22 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def keys(self) -> Generator[K, None, None]:
         """Return an iterator over the keys in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[0]
 
     def values(self) -> Generator[V, None, None]:
         """Return an iterator over the values in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[1]
 
     def items(self) -> Generator[tuple[K, V], None, None]:
         """Return an iterator over the items in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])
 
     def nbytes(self) -> int:
         """Return the size of the dictionary in bytes."""
-        return next(self.conn.execute("SELECT page_size * page_count FROM "
+        return next(self._exec_sql("SELECT page_size * page_count FROM "
                           "pragma_page_size(), pragma_page_count()"))[0]
 
     def __repr__(self) -> str:
@@ -580,7 +602,7 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def clear(self) -> None:
         """Remove all entries from the dictionary."""
-        self.conn.execute("DELETE FROM dict")
+        self._exec_sql("DELETE FROM dict")
 
 
 class WriteOncePersistentDict(_PersistentDictBase[K, V]):
@@ -644,11 +666,11 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
         v = pickle.dumps((key, value))
 
         if _skip_if_present:
-            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
+            self._exec_sql("INSERT OR IGNORE INTO dict VALUES (?, ?)",
                               (keyhash, v))
         else:
             try:
-                self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
+                self._exec_sql("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
             except sqlite3.IntegrityError as e:
                 if hasattr(e, "sqlite_errorcode"):
                     if e.sqlite_errorcode == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
@@ -662,7 +684,7 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
 
     def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
         # This method is separate from fetch() to allow for LRU caching
-        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                               (keyhash,))
         row = c.fetchone()
         if row is None:
@@ -730,17 +752,15 @@ class PersistentDict(_PersistentDictBase[K, V]):
         keyhash = self.key_builder(key)
         v = pickle.dumps((key, value))
 
-        if _skip_if_present:
-            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
-                              (keyhash, v))
-        else:
-            self.conn.execute("INSERT OR REPLACE INTO dict VALUES (?, ?)",
+        mode = "IGNORE" if _skip_if_present else "REPLACE"
+
+        self._exec_sql(f"INSERT OR {mode} INTO dict VALUES (?, ?)",
                               (keyhash, v))
 
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
 
-        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                               (keyhash,))
         row = c.fetchone()
         if row is None:
@@ -754,24 +774,26 @@ class PersistentDict(_PersistentDictBase[K, V]):
         """Remove the entry associated with *key* from the dictionary."""
         keyhash = self.key_builder(key)
 
-        self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
-
-        try:
-            # This is split into SELECT/DELETE to allow for a collision check
-            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
-                                  (keyhash,))
-            row = c.fetchone()
-            if row is None:
-                raise NoSuchEntryError(key)
-
-            stored_key, _value = pickle.loads(row[0])
-            self._collision_check(key, stored_key)
-
-            self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
-            self.conn.execute("COMMIT")
-        except Exception as e:
-            self.conn.execute("ROLLBACK")
-            raise e
+        def remove_inner() -> None:
+            self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
+            try:
+                # This is split into SELECT/DELETE to allow for a collision check
+                c = self.conn.execute("SELECT key_value FROM dict WHERE "
+                                      "keyhash=?", (keyhash,))
+                row = c.fetchone()
+                if row is None:
+                    raise NoSuchEntryError(key)
+
+                stored_key, _value = pickle.loads(row[0])
+                self._collision_check(key, stored_key)
+
+                self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
+                self.conn.execute("COMMIT")
+            except Exception:
+                self.conn.execute("ROLLBACK")
+                raise
+
+        self._exec_sql_fn(remove_inner)
 
     def __delitem__(self, key: K) -> None:
         """Remove the entry associated with *key* from the dictionary."""
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index bc67e66..2f6812c 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -899,6 +899,60 @@ def test_hash_function() -> None:
 # }}}
 
 
+# {{{ basic concurrency test
+
+def _mp_fn(tmpdir: str) -> None:
+    import time
+    pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
+                                                     container_dir=tmpdir,
+                                                     safe_sync=False)
+    n = 10000
+    s = 0
+
+    start = time.time()
+
+    for i in range(n):
+        if i % 100 == 0:
+            print(f"i={i}")
+        pdict[i] = i
+
+        try:
+            s += pdict[i]
+        except NoSuchEntryError:
+            # Someone else already deleted the entry
+            pass
+
+        try:
+            del pdict[i]
+        except NoSuchEntryError:
+            # Someone else already deleted the entry
+            pass
+
+    end = time.time()
+
+    print(f"PersistentDict: time taken to write {n} entries to "
+          f"{pdict.filename}: {end-start} s={s}")
+
+
+def test_concurrency() -> None:
+    from multiprocessing import Process
+
+    tmpdir = "_tmp/"  # must be the same across all processes in this test
+
+    try:
+        p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
+        for pp in p:
+            pp.start()
+        for pp in p:
+            pp.join()
+
+        assert all(pp.exitcode == 0 for pp in p), [pp.exitcode for pp in p]
+    finally:
+        shutil.rmtree(tmpdir)
+
+# }}}
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
diff --git a/setup.py b/setup.py
index 165aa3a..f082237 100644
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,7 @@
 
 from setuptools import find_packages, setup
 
+
 ver_dic = {}
 version_file = open("pytools/version.py")
 try:
-- 
GitLab