diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index ed4fce98e8c522d19c4dfd347c510f3292828e63..5d1baf9f566127c8dce38c0252a3023e093014f1 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -472,9 +472,17 @@ class _PersistentDictBase(Mapping[K, V]):
         self.container_dir = container_dir
         self._make_container_dir()
 
-        # isolation_level=None: enable autocommit mode
-        # https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
-        self.conn = sqlite3.connect(self.filename, isolation_level=None)
+        from threading import Lock
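+        # serializes all access to self.conn across threads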
+        self.mutex = Lock()
+
+        # * isolation_level=None: enable autocommit mode
+        #   https://www.sqlite.org/lang_transaction.html#implicit_versus_explicit_transactions
+        # * check_same_thread=False: thread-level concurrency is handled by the
+        #   mutex above
+        self.conn = sqlite3.connect(self.filename,
+                                    isolation_level=None,
+                                    check_same_thread=False)
 
         self._exec_sql(
             "CREATE TABLE IF NOT EXISTS dict "
@@ -515,8 +523,11 @@ class _PersistentDictBase(Mapping[K, V]):
         self._exec_sql("PRAGMA cache_size = -64000")
 
     def __del__(self) -> None:
-        if self.conn:
-            self.conn.close()
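+        # Hold the mutex so that the connection is not closed while
+        # another thread is still using it.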
+        with self.mutex:
+            if self.conn:
+                self.conn.close()
 
     def _collision_check(self, key: K, stored_key: K) -> None:
         if stored_key != key:
@@ -550,21 +561,24 @@ class _PersistentDictBase(Mapping[K, V]):
     def _exec_sql_fn(self, fn: Callable[[], T]) -> Optional[T]:
         n = 0
 
-        while True:
-            n += 1
-            try:
-                return fn()
-            except sqlite3.OperationalError as e:
-                # If the database is busy, retry
-                if (hasattr(e, "sqlite_errorcode")
-                        and not e.sqlite_errorcode == sqlite3.SQLITE_BUSY):
-                    raise
-                if n % 20 == 0:
-                    from warnings import warn
-                    warn(f"PersistentDict: database '{self.filename}' busy, {n} "
-                         "retries", stacklevel=3)
-            else:
-                break
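+        # Hold the mutex for the entire call so that fn() can combine
+        # execute() and fetchone() into one thread-atomic operation.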
+        with self.mutex:
+            while True:
+                n += 1
+                try:
+                    return fn()
+                except sqlite3.OperationalError as e:
+                    # If the database is busy, retry
+                    if (hasattr(e, "sqlite_errorcode")
+                            and e.sqlite_errorcode != sqlite3.SQLITE_BUSY):
+                        raise
+                    if n % 20 == 0:
+                        from warnings import warn
+                        warn(f"PersistentDict: database '{self.filename}' busy, {n} "
+                             "retries", stacklevel=3)
+                else:
+                    break
 
     def store_if_not_present(self, key: K, value: V) -> None:
         """Store (*key*, *value*) if *key* is not already present."""
@@ -716,9 +730,19 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
 
     def _fetch_uncached(self, keyhash: str) -> Tuple[K, V]:
         # This method is separate from fetch() to allow for LRU caching
-        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
-                              (keyhash,))
-        row = c.fetchone()
+
+        def fetch_inner() -> Optional[Tuple[Any]]:
+            assert self.conn is not None
+
+            # The query runs inside this inner function so that the mutex
+            # taken in _exec_sql_fn() also covers the fetchone() call.
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            res = c.fetchone()
+            assert res is None or isinstance(res, tuple)
+            return res
+
+        row = self._exec_sql_fn(fetch_inner)
         if row is None:
             raise KeyError
 
@@ -797,9 +821,19 @@ class PersistentDict(_PersistentDictBase[K, V]):
     def fetch(self, key: K) -> V:
         keyhash = self.key_builder(key)
 
-        c = self._exec_sql("SELECT key_value FROM dict WHERE keyhash=?",
-                              (keyhash,))
-        row = c.fetchone()
+        def fetch_inner() -> Optional[Tuple[Any]]:
+            assert self.conn is not None
+
+            # The query runs inside this inner function so that the mutex
+            # taken in _exec_sql_fn() also covers the fetchone() call.
+            c = self.conn.execute("SELECT key_value FROM dict WHERE keyhash=?",
+                                  (keyhash,))
+            res = c.fetchone()
+            assert res is None or isinstance(res, tuple)
+            return res
+
+        row = self._exec_sql_fn(fetch_inner)
+
         if row is None:
             raise NoSuchEntryError(key)
 
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index b0e050ed56e565c0c787364ae6e883e8dafc6660..858f22fb43e735cd593c922548963305a9f0db67 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -3,7 +3,8 @@ import sys  # noqa
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, IntEnum
-from typing import Any, Dict
+from threading import Thread
+from typing import Any, Dict, Optional
 
 import pytest
 
@@ -905,22 +906,37 @@ def test_hash_function() -> None:
     # }}}
 
 
-# {{{ basic concurrency test
+# {{{ basic concurrency tests
 
-def _mp_fn(tmpdir: str) -> None:
+def _conc_fn(tmpdir: Optional[str] = None,
+             pdict: Optional[PersistentDict[int, int]] = None) -> None:
     import time
-    pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
-                                                    container_dir=tmpdir,
-                                                    safe_sync=False)
+
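+    # Exactly one of tmpdir and pdict may be given: processes each open
+    # their own dict in tmpdir, while threads share a single pdict.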
+    assert (pdict is None) ^ (tmpdir is None)
+
+    if pdict is None:
+        pdict = PersistentDict("pytools-test",
+                               container_dir=tmpdir,
+                               safe_sync=False)
     n = 10000
     s = 0
 
     start = time.time()
 
     for i in range(n):
-        if i % 100 == 0:
+        if i % 1000 == 0:
             print(f"i={i}")
-        pdict[i] = i
+
+        if isinstance(pdict, WriteOncePersistentDict):
+            try:
+                pdict[i] = i
+            except ReadOnlyEntryError:
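+                # Someone else already stored the entry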
+                pass
+        else:
+            pdict[i] = i
 
         try:
             s += pdict[i]
@@ -928,11 +944,12 @@ def _mp_fn(tmpdir: str) -> None:
             # Someone else already deleted the entry
             pass
 
-        try:
-            del pdict[i]
-        except NoSuchEntryError:
-            # Someone else already deleted the entry
-            pass
+        if not isinstance(pdict, WriteOncePersistentDict):
+            try:
+                del pdict[i]
+            except NoSuchEntryError:
+                # Someone else already deleted the entry
+                pass
 
     end = time.time()
 
@@ -940,13 +957,15 @@ def _mp_fn(tmpdir: str) -> None:
         f"{pdict.filename}: {end-start} s={s}")
 
 
-def test_concurrency() -> None:
+def test_concurrency_processes() -> None:
     from multiprocessing import Process
 
-    tmpdir = "_tmp/"  # must be the same across all processes in this test
+    tmpdir = "_tmp_proc/"  # must be the same across all processes in this test
 
     try:
-        p = [Process(target=_mp_fn, args=(tmpdir, )) for _ in range(4)]
+        # multiprocessing needs to pickle function arguments, so we can't pass
+        # the PersistentDict object (which is unpicklable) directly.
+        p = [Process(target=_conc_fn, args=(tmpdir, None)) for _ in range(4)]
         for pp in p:
             pp.start()
         for pp in p:
@@ -956,6 +975,56 @@
     finally:
         shutil.rmtree(tmpdir)
 
+
+
+class RaisingThread(Thread):
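+    """A Thread that captures an exception raised in run() and re-raises
+    it from join(), so that a failure in a worker thread fails the test."""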
+    def run(self) -> None:
+        self._exc = None
+        try:
+            super().run()
+        except Exception as e:
+            self._exc = e
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        super().join(timeout=timeout)
+        if self._exc:
+            raise self._exc
+
+
+def test_concurrency_threads() -> None:
+    tmpdir = "_tmp_threads/"  # must be the same across all threads in this test
+
+    try:
+        # Share this pdict object among all threads to test thread safety
+        pdict: PersistentDict[int, int] = PersistentDict("pytools-test",
+                                                         container_dir=tmpdir,
+                                                         safe_sync=False)
+        t = [RaisingThread(target=_conc_fn, args=(None, pdict)) for _ in range(4)]
+        for tt in t:
+            tt.start()
+        for tt in t:
+            tt.join()
+            # Threads will raise in join() if they encountered an exception
+    finally:
+        shutil.rmtree(tmpdir)
+
+    try:
+        # Share this pdict object among all threads to test thread safety
+        pdict2: WriteOncePersistentDict[int, int] = WriteOncePersistentDict(
+                                                    "pytools-test",
+                                                    container_dir=tmpdir,
+                                                    safe_sync=False)
+        t = [RaisingThread(target=_conc_fn, args=(None, pdict2)) for _ in range(4)]
+        for tt in t:
+            tt.start()
+        for tt in t:
+            tt.join()
+            # Threads will raise in join() if they encountered an exception
+    finally:
+        shutil.rmtree(tmpdir)
+
 # }}}