diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 857299f17798ed51d34f4374ac2950a3ac4721ee..aaca40fabb0f50f7b8ee6bc7ed873b2b977e6d48 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -29,7 +29,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-import collections.abc as abc
 import errno
 import hashlib
 import logging
@@ -443,125 +442,6 @@ class KeyBuilder:
 # }}}
 
 
-# {{{ lru cache
-
-class _LinkedList:
-    """The list operates on nodes of the form [value, leftptr, rightpr]. To create a
-    node of this form you can use `LinkedList.new_node().`
-
-    Supports inserting at the left and deleting from an arbitrary location.
-    """
-    def __init__(self):
-        self.count = 0
-        self.head = None
-        self.end = None
-
-    @staticmethod
-    def new_node(element):
-        return [element, None, None]
-
-    def __len__(self):
-        return self.count
-
-    def appendleft_node(self, node):
-        self.count += 1
-
-        if self.head is None:
-            self.head = self.end = node
-            return
-
-        self.head[1] = node
-        node[2] = self.head
-
-        self.head = node
-
-    def pop_node(self):
-        end = self.end
-        self.remove_node(end)
-        return end
-
-    def remove_node(self, node):
-        self.count -= 1
-
-        if self.head is self.end:
-            assert node is self.head
-            self.head = self.end = None
-            return
-
-        left = node[1]
-        right = node[2]
-
-        if left is None:
-            self.head = right
-        else:
-            left[2] = right
-
-        if right is None:
-            self.end = left
-        else:
-            right[1] = left
-
-        node[1] = node[2] = None
-
-
-class _LRUCache(abc.MutableMapping):
-    """A mapping that keeps at most *maxsize* items with an LRU replacement policy.
-    """
-    def __init__(self, maxsize):
-        self.lru_order = _LinkedList()
-        self.maxsize = maxsize
-        self.cache = {}
-
-    def __delitem__(self, item):
-        node = self.cache[item]
-        self.lru_order.remove_node(node)
-        del self.cache[item]
-
-    def __getitem__(self, item):
-        node = self.cache[item]
-        self.lru_order.remove_node(node)
-        self.lru_order.appendleft_node(node)
-        # A linked list node contains a tuple of the form (item, value).
-        return node[0][1]
-
-    def __contains__(self, item):
-        return item in self.cache
-
-    def __iter__(self):
-        return iter(self.cache)
-
-    def __len__(self) -> int:
-        return len(self.cache)
-
-    def clear(self):
-        self.cache.clear()
-        self.lru_order = _LinkedList()
-
-    def __setitem__(self, item, value):
-        if self.maxsize < 1:
-            return
-
-        try:
-            node = self.cache[item]
-            self.lru_order.remove_node(node)
-        except KeyError:
-            if len(self.lru_order) >= self.maxsize:
-                # Make room for new elements.
-                end_node = self.lru_order.pop_node()
-                del self.cache[end_node[0][0]]
-
-            node = self.lru_order.new_node((item, value))
-            self.cache[item] = node
-
-        self.lru_order.appendleft_node(node)
-
-        assert len(self.cache) == len(self.lru_order), \
-                (len(self.cache), len(self.lru_order))
-        assert len(self.lru_order) <= self.maxsize
-
-# }}}
-
-
 # {{{ top-level
 
 class NoSuchEntryError(KeyError):
@@ -720,7 +600,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
     """A concurrent disk-backed dictionary that disallows overwriting/deletion.
 
     Compared with :class:`PersistentDict`, this class has faster
-    retrieval times.
+    retrieval times because it uses an LRU cache to cache entries in memory.
 
     .. automethod:: __init__
     .. automethod:: __getitem__
@@ -742,14 +622,15 @@ class WriteOncePersistentDict(_PersistentDictBase):
         """
         _PersistentDictBase.__init__(self, identifier, key_builder, container_dir)
         self._in_mem_cache_size = in_mem_cache_size
-        self.clear_in_mem_cache()
+        from functools import lru_cache
+        self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch)
 
     def clear_in_mem_cache(self) -> None:
         """
         .. versionadded:: 2023.1.1
         """
 
-        self._cache = _LRUCache(self._in_mem_cache_size)
+        self._fetch.cache_clear()
 
     def _spin_until_removed(self, lock_file, stacklevel):
         from os.path import exists
@@ -807,19 +688,14 @@ class WriteOncePersistentDict(_PersistentDictBase):
     def fetch(self, key, _stacklevel=0):
         hexdigest_key = self.key_builder(key)
 
-        # {{{ in memory cache
+        (stored_key, stored_value) = self._fetch(hexdigest_key, 1 + _stacklevel)
 
-        try:
-            stored_key, stored_value = self._cache[hexdigest_key]
-        except KeyError:
-            pass
-        else:
-            logger.debug("%s: in mem cache hit [key=%s]",
-                    self.identifier, hexdigest_key)
-            self._collision_check(key, stored_key, 1 + _stacklevel)
-            return stored_value
+        self._collision_check(key, stored_key, 1 + _stacklevel)
 
-        # }}}
+        return stored_value
+
+    def _fetch(self, hexdigest_key, _stacklevel=0):  # pylint:disable=method-hidden
+        # This is separate from fetch() to allow for LRU caching
 
         # {{{ check path exists and is unlocked
 
@@ -829,7 +705,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
         if not isdir(item_dir):
             logger.debug("%s: disk cache miss [key=%s]",
                     self.identifier, hexdigest_key)
-            raise NoSuchEntryError(key)
+            raise NoSuchEntryError(hexdigest_key)
 
         lock_file = self._lock_file(hexdigest_key)
         self._spin_until_removed(lock_file, 1 + _stacklevel)
@@ -852,9 +728,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
                     f"Remove the directory '{item_dir}' if necessary. "
                     f"(caught: {type(e).__name__}: {e})",
                     stacklevel=1 + _stacklevel)
-            raise NoSuchEntryInvalidKeyError(key)
-
-        self._collision_check(key, read_key, 1 + _stacklevel)
+            raise NoSuchEntryInvalidKeyError(hexdigest_key)
 
         # }}}
 
@@ -871,16 +745,15 @@ class WriteOncePersistentDict(_PersistentDictBase):
                     f"Remove the directory '{item_dir}' if necessary."
                     f"(caught: {type(e).__name__}: {e})",
                     stacklevel=1 + _stacklevel)
-            raise NoSuchEntryInvalidContentsError(key)
+            raise NoSuchEntryInvalidContentsError(hexdigest_key)
 
         # }}}
 
-        self._cache[hexdigest_key] = (key, read_contents)
-        return read_contents
+        return (read_key, read_contents)
 
     def clear(self):
         _PersistentDictBase.clear(self)
-        self._cache.clear()
+        self._fetch.cache_clear()
 
 
 class PersistentDict(_PersistentDictBase):
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index 73d8f6e6c00c23e416431d05cf24e7c56acd8bef..0989bb06b316584fb04088f87ee1e4f6536fe20b 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -309,6 +309,14 @@ def test_write_once_persistent_dict_lru_policy():
         pdict.fetch(4)
         assert pdict.fetch(1) is not val1
 
+        # test clear_in_mem_cache
+        val1 = pdict.fetch(1)
+        pdict.clear_in_mem_cache()
+        assert pdict.fetch(1) is not val1
+
+        val1 = pdict.fetch(1)
+        assert pdict.fetch(1) is val1
+
     finally:
         shutil.rmtree(tmpdir)