diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 857299f17798ed51d34f4374ac2950a3ac4721ee..aaca40fabb0f50f7b8ee6bc7ed873b2b977e6d48 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -29,7 +29,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import collections.abc as abc import errno import hashlib import logging @@ -443,125 +442,6 @@ class KeyBuilder: # }}} -# {{{ lru cache - -class _LinkedList: - """The list operates on nodes of the form [value, leftptr, rightpr]. To create a - node of this form you can use `LinkedList.new_node().` - - Supports inserting at the left and deleting from an arbitrary location. - """ - def __init__(self): - self.count = 0 - self.head = None - self.end = None - - @staticmethod - def new_node(element): - return [element, None, None] - - def __len__(self): - return self.count - - def appendleft_node(self, node): - self.count += 1 - - if self.head is None: - self.head = self.end = node - return - - self.head[1] = node - node[2] = self.head - - self.head = node - - def pop_node(self): - end = self.end - self.remove_node(end) - return end - - def remove_node(self, node): - self.count -= 1 - - if self.head is self.end: - assert node is self.head - self.head = self.end = None - return - - left = node[1] - right = node[2] - - if left is None: - self.head = right - else: - left[2] = right - - if right is None: - self.end = left - else: - right[1] = left - - node[1] = node[2] = None - - -class _LRUCache(abc.MutableMapping): - """A mapping that keeps at most *maxsize* items with an LRU replacement policy. - """ - def __init__(self, maxsize): - self.lru_order = _LinkedList() - self.maxsize = maxsize - self.cache = {} - - def __delitem__(self, item): - node = self.cache[item] - self.lru_order.remove_node(node) - del self.cache[item] - - def __getitem__(self, item): - node = self.cache[item] - self.lru_order.remove_node(node) - self.lru_order.appendleft_node(node) - # A linked list node contains a tuple of the form (item, value). - return node[0][1] - - def __contains__(self, item): - return item in self.cache - - def __iter__(self): - return iter(self.cache) - - def __len__(self) -> int: - return len(self.cache) - - def clear(self): - self.cache.clear() - self.lru_order = _LinkedList() - - def __setitem__(self, item, value): - if self.maxsize < 1: - return - - try: - node = self.cache[item] - self.lru_order.remove_node(node) - except KeyError: - if len(self.lru_order) >= self.maxsize: - # Make room for new elements. - end_node = self.lru_order.pop_node() - del self.cache[end_node[0][0]] - - node = self.lru_order.new_node((item, value)) - self.cache[item] = node - - self.lru_order.appendleft_node(node) - - assert len(self.cache) == len(self.lru_order), \ - (len(self.cache), len(self.lru_order)) - assert len(self.lru_order) <= self.maxsize - -# }}} - - # {{{ top-level class NoSuchEntryError(KeyError): @@ -720,7 +600,7 @@ class WriteOncePersistentDict(_PersistentDictBase): """A concurrent disk-backed dictionary that disallows overwriting/deletion. Compared with :class:`PersistentDict`, this class has faster - retrieval times. + retrieval times because it uses an LRU cache to cache entries in memory. .. automethod:: __init__ .. automethod:: __getitem__ @@ -742,14 +622,15 @@ class WriteOncePersistentDict(_PersistentDictBase): """ _PersistentDictBase.__init__(self, identifier, key_builder, container_dir) self._in_mem_cache_size = in_mem_cache_size - self.clear_in_mem_cache() + from functools import lru_cache + self._fetch = lru_cache(maxsize=in_mem_cache_size)(self._fetch) def clear_in_mem_cache(self) -> None: """ .. versionadded:: 2023.1.1 """ - self._cache = _LRUCache(self._in_mem_cache_size) + self._fetch.cache_clear() def _spin_until_removed(self, lock_file, stacklevel): from os.path import exists @@ -807,19 +688,14 @@ class WriteOncePersistentDict(_PersistentDictBase): def fetch(self, key, _stacklevel=0): hexdigest_key = self.key_builder(key) - # {{{ in memory cache + (stored_key, stored_value) = self._fetch(hexdigest_key, 1 + _stacklevel) - try: - stored_key, stored_value = self._cache[hexdigest_key] - except KeyError: - pass - else: - logger.debug("%s: in mem cache hit [key=%s]", - self.identifier, hexdigest_key) - self._collision_check(key, stored_key, 1 + _stacklevel) - return stored_value + self._collision_check(key, stored_key, 1 + _stacklevel) - # }}} + return stored_value + + def _fetch(self, hexdigest_key, _stacklevel=0): # pylint:disable=method-hidden + # This is separate from fetch() to allow for LRU caching # {{{ check path exists and is unlocked @@ -829,7 +705,7 @@ class WriteOncePersistentDict(_PersistentDictBase): if not isdir(item_dir): logger.debug("%s: disk cache miss [key=%s]", self.identifier, hexdigest_key) - raise NoSuchEntryError(key) + raise NoSuchEntryError(hexdigest_key) lock_file = self._lock_file(hexdigest_key) self._spin_until_removed(lock_file, 1 + _stacklevel) @@ -852,9 +728,7 @@ class WriteOncePersistentDict(_PersistentDictBase): f"Remove the directory '{item_dir}' if necessary. " f"(caught: {type(e).__name__}: {e})", stacklevel=1 + _stacklevel) - raise NoSuchEntryInvalidKeyError(key) - - self._collision_check(key, read_key, 1 + _stacklevel) + raise NoSuchEntryInvalidKeyError(hexdigest_key) # }}} @@ -871,16 +745,15 @@ class WriteOncePersistentDict(_PersistentDictBase): f"Remove the directory '{item_dir}' if necessary." f"(caught: {type(e).__name__}: {e})", stacklevel=1 + _stacklevel) - raise NoSuchEntryInvalidContentsError(key) + raise NoSuchEntryInvalidContentsError(hexdigest_key) # }}} - self._cache[hexdigest_key] = (key, read_contents) - return read_contents + return (read_key, read_contents) def clear(self): _PersistentDictBase.clear(self) - self._cache.clear() + self._fetch.cache_clear() class PersistentDict(_PersistentDictBase): diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index 73d8f6e6c00c23e416431d05cf24e7c56acd8bef..0989bb06b316584fb04088f87ee1e4f6536fe20b 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -309,6 +309,14 @@ def test_write_once_persistent_dict_lru_policy(): pdict.fetch(4) assert pdict.fetch(1) is not val1 + # test clear_in_mem_cache + val1 = pdict.fetch(1) + pdict.clear_in_mem_cache() + assert pdict.fetch(1) is not val1 + + val1 = pdict.fetch(1) + assert pdict.fetch(1) is val1 + finally: shutil.rmtree(tmpdir)