From 93eab0ff8d7d765aea7bb0424ab80b2280ecc8ce Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Fri, 7 Jun 2024 13:59:49 -0600
Subject: [PATCH] PersistentDict: concurrency improvements (#231)

---
 pytools/persistent_dict.py           | 100 ++++++++++++++++-----------
 pytools/test/test_persistent_dict.py |  54 +++++++++++++++
 setup.py                             |   1 +
 3 files changed, 116 insertions(+), 39 deletions(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 7af3e28..56dac06 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -463,21 +463,21 @@ class _PersistentDictBase(Mapping[K, V]):
         # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
         self.conn = sqlite3.connect(self.filename, isolation_level=None)
 
-        self.conn.execute(
+        self._exec_sql(
             "CREATE TABLE IF NOT EXISTS dict "
             "(keyhash TEXT NOT NULL PRIMARY KEY, key_value TEXT NOT NULL)"
             )
 
         # https://www.sqlite.org/wal.html
         if enable_wal:
-            self.conn.execute("PRAGMA journal_mode = 'WAL'")
+            self._exec_sql("PRAGMA journal_mode = 'WAL'")
 
         # Note: the following configuration values were taken mostly from litedict:
         # https://github.com/litements/litedict/blob/377603fa597453ffd9997186a493ed4fd23e5399/litedict.py#L67-L70
 
         # Use in-memory temp store
         # https://www.sqlite.org/pragma.html#pragma_temp_store
-        self.conn.execute("PRAGMA temp_store = 'MEMORY'")
+        self._exec_sql("PRAGMA temp_store = 'MEMORY'")
 
         # fsync() can be extremely slow on some systems.
         # See https://github.com/inducer/pytools/issues/227 for context.
@@ -493,13 +493,13 @@ class _PersistentDictBase(Mapping[K, V]):
                      "Pass 'safe_sync=False' if occasional data loss is tolerable. "
                      "Pass 'safe_sync=True' to suppress this warning.",
                      stacklevel=3)
-            self.conn.execute("PRAGMA synchronous = 'NORMAL'")
+            self._exec_sql("PRAGMA synchronous = 'NORMAL'")
         else:
-            self.conn.execute("PRAGMA synchronous = 'OFF'")
+            self._exec_sql("PRAGMA synchronous = 'OFF'")
 
         # 64 MByte of cache
         # https://www.sqlite.org/pragma.html#pragma_cache_size
-        self.conn.execute("PRAGMA cache_size = -64000")
+        self._exec_sql("PRAGMA cache_size = -64000")
 
     def __del__(self) -> None:
         if self.conn:
@@ -522,6 +522,28 @@ class _PersistentDictBase(Mapping[K, V]):
             stored_key == key  # pylint:disable=pointless-statement  # noqa: B015
             raise NoSuchEntryCollisionError(key)
 
+    def _exec_sql(self, *args: Any) -> sqlite3.Cursor:
+        return self._exec_sql_fn(lambda: self.conn.execute(*args))
+
+    def _exec_sql_fn(self, fn: Any) -> Any:
+        n = 0
+
+        while True:
+            n += 1
+            try:
+                return fn()
+            except sqlite3.OperationalError as e:
+                # If the database is busy, retry
+                if (hasattr(e, "sqlite_errorcode")
+                        and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
+                    raise
+                if n % 20 == 0:
+                    from warnings import warn
+                    warn(f"PersistentDict: database '{self.filename}' busy, {n} "
+                         "retries")
+            else:
+                break
+
     def store_if_not_present(self, key: K, value: V) -> None:
         """Store (*key*, *value*) if *key* is not already present."""
         self.store(key, value, _skip_if_present=True)
@@ -548,7 +570,7 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def __len__(self) -> int:
         """Return the number of entries in the dictionary."""
-        return next(self.conn.execute("SELECT COUNT(*) FROM dict"))[0]
+        return next(self._exec_sql("SELECT COUNT(*) FROM dict"))[0]
 
     def __iter__(self) -> Generator[K, None, None]:
         """Return an iterator over the keys in the dictionary."""
@@ -556,22 +578,22 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def keys(self) -> Generator[K, None, None]:
         """Return an iterator over the keys in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[0]
 
     def values(self) -> Generator[V, None, None]:
         """Return an iterator over the values in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])[1]
 
     def items(self) -> Generator[tuple[K, V], None, None]:
         """Return an iterator over the items in the dictionary."""
-        for row in self.conn.execute("SELECT key_value FROM dict ORDER BY rowid"):
+        for row in self._exec_sql("SELECT key_value FROM dict ORDER BY rowid"):
             yield pickle.loads(row[0])
 
     def nbytes(self) -> int:
         """Return the size of the dictionary in bytes."""
-        return next(self.conn.execute("SELECT page_size * page_count FROM "
+        return next(self._exec_sql("SELECT page_size * page_count FROM "
                           "pragma_page_size(), pragma_page_count()"))[0]
 
     def __repr__(self) -> str:
@@ -580,7 +602,7 @@ class _PersistentDictBase(Mapping[K, V]):
 
     def clear(self) -> None:
         """Remove all entries from the dictionary."""
-        self.conn.execute("DELETE FROM dict")
+        self._exec_sql("DELETE FROM dict")
 
 
 class WriteOncePersistentDict(_PersistentDictBase[K, V]):
@@ -644,11 +666,11 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
         v = pickle.dumps((key, value))
 
         if _skip_if_present:
-            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
+            self._exec_sql("INSERT OR IGNORE INTO dict VALUES (?, ?)",
                               (keyhash, v))
         else:
             try:
-                self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
+                self._exec_sql("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
             except sqlite3.IntegrityError as e:
                 if hasattr(e, "sqlite_errorcode"):
                     if e.sqlite_errorcode == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
@@ -662,7 +684,7 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
 
     def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
         # This method is separate from fetch() to allow for LRU caching
-        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                               (keyhash,))
         row = c.fetchone()
         if row is None:
@@ -730,17 +752,15 @@ class PersistentDict(_PersistentDictBase[K, V]):
         keyhash = self.key_builder(key)
         v = pickle.dumps((key, value))
 
-        if _skip_if_present:
-            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
-                              (keyhash, v))
-        else:
-            self.conn.execute("INSERT OR REPLACE INTO dict VALUES (?, ?)",
+        mode = "IGNORE" if _skip_if_present else "REPLACE"
+
+        self._exec_sql(f"INSERT OR {mode} INTO dict VALUES (?, ?)",
                               (keyhash, v))
 
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
 
-        c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
                               (keyhash,))
         row = c.fetchone()
         if row is None:
@@ -754,24 +774,26 @@ class PersistentDict(_PersistentDictBase[K, V]):
         """Remove the entry associated with *key* from the dictionary."""
         keyhash = self.key_builder(key)
 
-        self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
-
-        try:
-            # This is split into SELECT/DELETE to allow for a collision check
-            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
-                                (keyhash,))
-            row = c.fetchone()
-            if row is None:
-                raise NoSuchEntryError(key)
-
-            stored_key, _value = pickle.loads(row[0])
-            self._collision_check(key, stored_key)
-
-            self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
-            self.conn.execute("COMMIT")
-        except Exception as e:
-            self.conn.execute("ROLLBACK")
-            raise e
+        def remove_inner() -> None:
+            self.conn.execute("BEGIN EXCLUSIVE TRANSACTION")
+            try:
+                # This is split into SELECT/DELETE to allow for a collision check
+                c = self.conn.execute("SELECT key_value FROM dict WHERE "
+                                        "keyhash=?", (keyhash,))
+                row = c.fetchone()
+                if row is None:
+                    raise NoSuchEntryError(key)
+
+                stored_key, _value = pickle.loads(row[0])
+                self._collision_check(key, stored_key)
+
+                self.conn.execute("DELETE FROM dict WHERE keyhash=?", (keyhash,))
+                self.conn.execute("COMMIT")
+            except Exception as e:
+                self.conn.execute("ROLLBACK")
+                raise e
+
+        self._exec_sql_fn(remove_inner)
 
     def __delitem__(self, key: K) -> None:
         """Remove the entry associated with *key* from the dictionary."""
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index bc67e66..2f6812c 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -899,6 +899,60 @@ def test_hash_function() -> None:
     # }}}
 
 
+# {{{ basic concurrency test
+
+def _mp_fn(tmpdir: str) -> None:
+    import time
+    pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
+                                                    container_dir=tmpdir,
+                                                    safe_sync=False)
+    n = 10000
+    s = 0
+
+    start = time.time()
+
+    for i in range(n):
+        if i % 100 == 0:
+            print(f"i={i}")
+        pdict[i] = i
+
+        try:
+            s += pdict[i]
+        except NoSuchEntryError:
+            # Someone else already deleted the entry
+            pass
+
+        try:
+            del pdict[i]
+        except NoSuchEntryError:
+            # Someone else already deleted the entry
+            pass
+
+    end = time.time()
+
+    print(f"PersistentDict: time taken to write {n} entries to "
+        f"{pdict.filename}: {end-start} s={s}")
+
+
+def test_concurrency() -> None:
+    from multiprocessing import Process
+
+    tmpdir = "_tmp/"  # must be the same across all processes in this test
+
+    try:
+        p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
+        for pp in p:
+            pp.start()
+        for pp in p:
+            pp.join()
+
+        assert all(pp.exitcode == 0 for pp in p), [pp.exitcode for pp in p]
+    finally:
+        shutil.rmtree(tmpdir)
+
+# }}}
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])
diff --git a/setup.py b/setup.py
index 165aa3a..f082237 100644
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,7 @@
 
 from setuptools import find_packages, setup
 
+
 ver_dic = {}
 version_file = open("pytools/version.py")
 try:
-- 
GitLab