From c15051b6243bcdbeb3bbf98224b40ae7468992eb Mon Sep 17 00:00:00 2001
From: Matthias Diener <matthias.diener@gmail.com>
Date: Wed, 5 Jun 2024 07:49:41 -0600
Subject: [PATCH] PersistentDict: smarter WOPD.store (#230)

Passes through IntegrityErrors that are not ReadOnlyEntryErrors

Co-authored-by: Alex Fikl <alexfikl@gmail.com>
---
 pytools/persistent_dict.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 27860dd..7af3e28 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -643,12 +643,22 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
         keyhash = self.key_builder(key)
         v = pickle.dumps((key, value))
 
-        try:
-            self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
-        except sqlite3.IntegrityError:
-            if not _skip_if_present:
-                raise ReadOnlyEntryError("WriteOncePersistentDict, "
-                                         "tried overwriting key")
+        if _skip_if_present:
+            self.conn.execute("INSERT OR IGNORE INTO dict VALUES (?, ?)",
+                              (keyhash, v))
+        else:
+            try:
+                self.conn.execute("INSERT INTO dict VALUES (?, ?)", (keyhash, v))
+            except sqlite3.IntegrityError as e:
+                if hasattr(e, "sqlite_errorcode"):
+                    if e.sqlite_errorcode == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
+                        raise ReadOnlyEntryError("WriteOncePersistentDict, "
+                                                 "tried overwriting key")
+                    else:
+                        raise
+                else:
+                    raise ReadOnlyEntryError("WriteOncePersistentDict, "
+                                             "tried overwriting key")
 
     def _fetch(self, keyhash: str) -> Tuple[K, V]:  # pylint:disable=method-hidden
         # This method is separate from fetch() to allow for LRU caching
-- 
GitLab