From 6f030709e55fd66a8ab0863d194cec318fedefd5 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Mon, 15 Apr 2024 11:45:35 -0500 Subject: [PATCH] PersistentDict: better docs, type annotations (#212) * better docs, type annotations * more annotations * add to README * make it generic * fix doc * complete annotations --- README.rst | 1 + doc/conf.py | 1 + pytools/persistent_dict.py | 114 +++++++++++++++++++-------- pytools/test/test_persistent_dict.py | 110 ++++++++++++++------------ run-mypy.sh | 2 +- 5 files changed, 142 insertions(+), 86 deletions(-) diff --git a/README.rst b/README.rst index 01afc39..1cbdccf 100644 --- a/README.rst +++ b/README.rst @@ -24,6 +24,7 @@ nonetheless, here's what's on offer: GvR's monkeypatch_xxx() hack, the elusive `flatten`, and much more. * Batch job submission, `pytools.batchjob`. * A lexer, `pytools.lex`. +* A persistent key-value store, `pytools.persistent_dict`. Links: diff --git a/doc/conf.py b/doc/conf.py index 0d4a936..13dcc2c 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -32,6 +32,7 @@ intersphinx_mapping = { "pytest": ("https://docs.pytest.org/en/stable/", None), "setuptools": ("https://setuptools.pypa.io/en/latest/", None), "python": ("https://docs.python.org/3", None), + "platformdirs": ("https://platformdirs.readthedocs.io/en/latest/", None), } nitpicky = True diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py index 499ebdc..9515f37 100644 --- a/pytools/persistent_dict.py +++ b/pytools/persistent_dict.py @@ -37,7 +37,7 @@ import shutil import sys from dataclasses import fields as dc_fields, is_dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Mapping, Protocol +from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, Protocol, TypeVar if TYPE_CHECKING: @@ -75,6 +75,18 @@ This module also provides a disk-backed dictionary that uses persistent hashing. .. autoclass:: KeyBuilder .. autoclass:: PersistentDict .. autoclass:: WriteOncePersistentDict + + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: K + + A type variable for the key type of a :class:`PersistentDict`. + +.. class:: V + + A type variable for the value type of a :class:`PersistentDict`. """ @@ -478,8 +490,14 @@ class CollisionWarning(UserWarning): pass -class _PersistentDictBase: - def __init__(self, identifier, key_builder=None, container_dir=None): +K = TypeVar("K") +V = TypeVar("V") + + +class _PersistentDictBase(Generic[K, V]): + def __init__(self, identifier: str, + key_builder: Optional[KeyBuilder] = None, + container_dir: Optional[str] = None) -> None: self.identifier = identifier if key_builder is None: @@ -509,32 +527,37 @@ class _PersistentDictBase: self._make_container_dir() @staticmethod - def _warn(msg, category=UserWarning, stacklevel=0): + def _warn(msg: str, category: Any = UserWarning, stacklevel: int = 0) -> None: from warnings import warn warn(msg, category, stacklevel=1 + stacklevel) - def store_if_not_present(self, key, value, _stacklevel=0): + def store_if_not_present(self, key: K, value: V, + _stacklevel: int = 0) -> None: + """Store (*key*, *value*) if *key* is not already present.""" self.store(key, value, _skip_if_present=True, _stacklevel=1 + _stacklevel) - def store(self, key, value, _skip_if_present=False, _stacklevel=0): + def store(self, key: K, value: V, _skip_if_present: bool = False, + _stacklevel: int = 0) -> None: + """Store (*key*, *value*) in the dictionary.""" raise NotImplementedError() - def fetch(self, key, _stacklevel=0): + def fetch(self, key: K, _stacklevel: int = 0) -> V: + """Return the value associated with *key* in the dictionary.""" raise NotImplementedError() @staticmethod - def _read(path): + def _read(path: str) -> V: from pickle import load with open(path, "rb") as inf: return load(inf) @staticmethod - def _write(path, value): + def _write(path: str, value: V) -> None: from pickle import HIGHEST_PROTOCOL, dump with open(path, "wb") as outf: dump(value, outf, protocol=HIGHEST_PROTOCOL) - def _item_dir(self, hexdigest_key): + def _item_dir(self, hexdigest_key: str) -> str: from os.path import join # Some file systems limit the number of directories in a directory. @@ -546,22 +569,23 @@ class _PersistentDictBase: hexdigest_key[3:6], hexdigest_key[6:]) - def _key_file(self, hexdigest_key): + def _key_file(self, hexdigest_key: str) -> str: from os.path import join return join(self._item_dir(hexdigest_key), "key") - def _contents_file(self, hexdigest_key): + def _contents_file(self, hexdigest_key: str) -> str: from os.path import join return join(self._item_dir(hexdigest_key), "contents") - def _lock_file(self, hexdigest_key): + def _lock_file(self, hexdigest_key: str) -> str: from os.path import join return join(self.container_dir, str(hexdigest_key) + ".lock") - def _make_container_dir(self): + def _make_container_dir(self) -> None: + """Create the container directory to store the dictionary.""" os.makedirs(self.container_dir, exist_ok=True) - def _collision_check(self, key, stored_key, _stacklevel): + def _collision_check(self, key: K, stored_key: K, _stacklevel: int) -> None: if stored_key != key: # Key collision, oh well. self._warn(f"{self.identifier}: key collision in cache at " @@ -577,13 +601,16 @@ class _PersistentDictBase: stored_key == key # pylint:disable=pointless-statement # noqa: B015 raise NoSuchEntryCollisionError(key) - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: + """Return the value associated with *key* in the dictionary.""" return self.fetch(key, _stacklevel=1) - def __setitem__(self, key, value): + def __setitem__(self, key: K, value: V) -> None: + """Store (*key*, *value*) in the dictionary.""" self.store(key, value, _stacklevel=1) - def clear(self): + def clear(self) -> None: + """Remove all entries from the dictionary.""" try: shutil.rmtree(self.container_dir) except OSError as e: @@ -593,8 +620,9 @@ class _PersistentDictBase: self._make_container_dir() -class WriteOncePersistentDict(_PersistentDictBase): - """A concurrent disk-backed dictionary that disallows overwriting/deletion. +class WriteOncePersistentDict(_PersistentDictBase[K, V]): + """A concurrent disk-backed dictionary that disallows overwriting/ + deletion (but allows removing all entries). Compared with :class:`PersistentDict`, this class has faster retrieval times because it uses an LRU cache to cache entries in memory. @@ -608,14 +636,19 @@ class WriteOncePersistentDict(_PersistentDictBase): .. automethod:: store_if_not_present .. automethod:: fetch """ - def __init__(self, identifier, key_builder=None, container_dir=None, - in_mem_cache_size=256): + def __init__(self, identifier: str, + key_builder: Optional[KeyBuilder] = None, + container_dir: Optional[str] = None, + in_mem_cache_size: int = 256) -> None: """ :arg identifier: a file-name-compatible string identifying this dictionary :arg key_builder: a subclass of :class:`KeyBuilder` + :arg container_dir: the directory in which to store this + dictionary. If ``None``, the default cache directory from + :func:`platformdirs.user_cache_dir` is used :arg in_mem_cache_size: retain an in-memory cache of up to - *in_mem_cache_size* items + *in_mem_cache_size* items (with an LRU replacement policy) """ _PersistentDictBase.__init__(self, identifier, key_builder, container_dir) self._in_mem_cache_size = in_mem_cache_size @@ -624,12 +657,14 @@ class WriteOncePersistentDict(_PersistentDictBase): def clear_in_mem_cache(self) -> None: """ + Clear the in-memory cache of this dictionary. + .. versionadded:: 2023.1.1 """ self._fetch.cache_clear() - def _spin_until_removed(self, lock_file, stacklevel): + def _spin_until_removed(self, lock_file: str, stacklevel: int) -> None: from os.path import exists attempts = 0 @@ -649,7 +684,8 @@ class WriteOncePersistentDict(_PersistentDictBase): f"on the lock file '{lock_file}'" "--something is wrong") - def store(self, key, value, _skip_if_present=False, _stacklevel=0): + def store(self, key: K, value: V, _skip_if_present: bool = False, + _stacklevel: int = 0) -> None: hexdigest_key = self.key_builder(key) cleanup_m = CleanupManager() @@ -682,7 +718,7 @@ class WriteOncePersistentDict(_PersistentDictBase): finally: cleanup_m.clean_up() - def fetch(self, key, _stacklevel=0): + def fetch(self, key: K, _stacklevel: int = 0) -> Any: hexdigest_key = self.key_builder(key) (stored_key, stored_value) = self._fetch(hexdigest_key, 1 + _stacklevel) @@ -691,7 +727,8 @@ class WriteOncePersistentDict(_PersistentDictBase): return stored_value - def _fetch(self, hexdigest_key, _stacklevel=0): # pylint:disable=method-hidden + def _fetch(self, hexdigest_key: str, # pylint:disable=method-hidden + _stacklevel: int = 0) -> V: # This is separate from fetch() to allow for LRU caching # {{{ check path exists and is unlocked @@ -748,12 +785,12 @@ class WriteOncePersistentDict(_PersistentDictBase): return (read_key, read_contents) - def clear(self): + def clear(self) -> None: _PersistentDictBase.clear(self) self._fetch.cache_clear() -class PersistentDict(_PersistentDictBase): +class PersistentDict(_PersistentDictBase[K, V]): """A concurrent disk-backed dictionary. .. automethod:: __init__ @@ -766,15 +803,22 @@ class PersistentDict(_PersistentDictBase): .. automethod:: fetch .. automethod:: remove """ - def __init__(self, identifier, key_builder=None, container_dir=None): + def __init__(self, + identifier: str, + key_builder: Optional[KeyBuilder] = None, + container_dir: Optional[str] = None) -> None: """ :arg identifier: a file-name-compatible string identifying this dictionary :arg key_builder: a subclass of :class:`KeyBuilder` + :arg container_dir: the directory in which to store this + dictionary. If ``None``, the default cache directory from + :func:`platformdirs.user_cache_dir` is used """ _PersistentDictBase.__init__(self, identifier, key_builder, container_dir) - def store(self, key, value, _skip_if_present=False, _stacklevel=0): + def store(self, key: K, value: V, _skip_if_present: bool = False, + _stacklevel: int = 0) -> None: hexdigest_key = self.key_builder(key) cleanup_m = CleanupManager() @@ -807,7 +851,7 @@ class PersistentDict(_PersistentDictBase): finally: cleanup_m.clean_up() - def fetch(self, key, _stacklevel=0): + def fetch(self, key: K, _stacklevel: int = 0) -> V: hexdigest_key = self.key_builder(key) item_dir = self._item_dir(hexdigest_key) @@ -871,7 +915,8 @@ class PersistentDict(_PersistentDictBase): finally: cleanup_m.clean_up() - def remove(self, key, _stacklevel=0): + def remove(self, key: K, _stacklevel: int = 0) -> None: + """Remove the entry associated with *key* from the dictionary.""" hexdigest_key = self.key_builder(key) item_dir = self._item_dir(hexdigest_key) @@ -913,7 +958,8 @@ class PersistentDict(_PersistentDictBase): finally: cleanup_m.clean_up() - def __delitem__(self, key): + def __delitem__(self, key: K) -> None: + """Remove the entry associated with *key* from the dictionary.""" self.remove(key, _stacklevel=1) # }}} diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py index 0989bb0..1bebf61 100644 --- a/pytools/test/test_persistent_dict.py +++ b/pytools/test/test_persistent_dict.py @@ -3,6 +3,7 @@ import sys # noqa import tempfile from dataclasses import dataclass from enum import Enum, IntEnum +from typing import Any, Dict import pytest @@ -16,30 +17,25 @@ from pytools.tag import Tag, tag_dataclass class PDictTestingKeyOrValue: - def __init__(self, val, hash_key=None): + def __init__(self, val: Any, hash_key=None) -> None: self.val = val if hash_key is None: hash_key = val self.hash_key = hash_key - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return {"val": self.val, "hash_key": self.hash_key} - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.val == other.val - def __ne__(self, other): - return not self.__eq__(other) - - def update_persistent_hash(self, key_hash, key_builder): + def update_persistent_hash(self, key_hash: Any, key_builder: KeyBuilder) -> None: key_builder.rec(key_hash, self.hash_key) - def __repr__(self): + def __repr__(self) -> str: return "PDictTestingKeyOrValue(val={!r},hash_key={!r})".format( self.val, self.hash_key) - __str__ = __repr__ - # }}} @@ -64,10 +60,11 @@ class MyStruct: value: int -def test_persistent_dict_storage_and_lookup(): +def test_persistent_dict_storage_and_lookup() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = PersistentDict("pytools-test", container_dir=tmpdir) + pdict: PersistentDict[Any, int] = PersistentDict("pytools-test", + container_dir=tmpdir) from random import randrange @@ -77,7 +74,8 @@ def test_persistent_dict_storage_and_lookup(): for i in range(n)) keys = [ - (randrange(2000)-1000, rand_str(), None, SomeTag(rand_str()), + (randrange(2000)-1000, rand_str(), None, + SomeTag(rand_str()), # type: ignore[call-arg] frozenset({"abc", 123})) for i in range(20)] values = [randrange(2000) for i in range(20)] @@ -161,10 +159,11 @@ def test_persistent_dict_storage_and_lookup(): shutil.rmtree(tmpdir) -def test_persistent_dict_deletion(): +def test_persistent_dict_deletion() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = PersistentDict("pytools-test", container_dir=tmpdir) + pdict: PersistentDict[int, int] = PersistentDict("pytools-test", + container_dir=tmpdir) pdict[0] = 0 del pdict[0] @@ -179,11 +178,13 @@ def test_persistent_dict_deletion(): shutil.rmtree(tmpdir) -def test_persistent_dict_synchronization(): +def test_persistent_dict_synchronization() -> None: try: tmpdir = tempfile.mkdtemp() - pdict1 = PersistentDict("pytools-test", container_dir=tmpdir) - pdict2 = PersistentDict("pytools-test", container_dir=tmpdir) + pdict1: PersistentDict[int, int] = PersistentDict("pytools-test", + container_dir=tmpdir) + pdict2: PersistentDict[int, int] = PersistentDict("pytools-test", + container_dir=tmpdir) # check lookup pdict1[0] = 1 @@ -202,10 +203,11 @@ def test_persistent_dict_synchronization(): shutil.rmtree(tmpdir) -def test_persistent_dict_cache_collisions(): +def test_persistent_dict_cache_collisions() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = PersistentDict("pytools-test", container_dir=tmpdir) + pdict: PersistentDict[PDictTestingKeyOrValue, int] = \ + PersistentDict("pytools-test", container_dir=tmpdir) key1 = PDictTestingKeyOrValue(1, hash_key=0) key2 = PDictTestingKeyOrValue(2, hash_key=0) @@ -233,10 +235,11 @@ def test_persistent_dict_cache_collisions(): shutil.rmtree(tmpdir) -def test_persistent_dict_clear(): +def test_persistent_dict_clear() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = PersistentDict("pytools-test", container_dir=tmpdir) + pdict: PersistentDict[int, int] = PersistentDict("pytools-test", + container_dir=tmpdir) pdict[0] = 1 pdict.fetch(0) @@ -250,10 +253,10 @@ def test_persistent_dict_clear(): @pytest.mark.parametrize("in_mem_cache_size", (0, 256)) -def test_write_once_persistent_dict_storage_and_lookup(in_mem_cache_size): +def test_write_once_persistent_dict_storage_and_lookup(in_mem_cache_size) -> None: try: tmpdir = tempfile.mkdtemp() - pdict = WriteOncePersistentDict( + pdict: WriteOncePersistentDict[int, int] = WriteOncePersistentDict( "pytools-test", container_dir=tmpdir, in_mem_cache_size=in_mem_cache_size) @@ -281,10 +284,10 @@ def test_write_once_persistent_dict_storage_and_lookup(in_mem_cache_size): shutil.rmtree(tmpdir) -def test_write_once_persistent_dict_lru_policy(): +def test_write_once_persistent_dict_lru_policy() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = WriteOncePersistentDict( + pdict: WriteOncePersistentDict[Any, Any] = WriteOncePersistentDict( "pytools-test", container_dir=tmpdir, in_mem_cache_size=3) pdict[1] = PDictTestingKeyOrValue(1) @@ -321,11 +324,13 @@ def test_write_once_persistent_dict_lru_policy(): shutil.rmtree(tmpdir) -def test_write_once_persistent_dict_synchronization(): +def test_write_once_persistent_dict_synchronization() -> None: try: tmpdir = tempfile.mkdtemp() - pdict1 = WriteOncePersistentDict("pytools-test", container_dir=tmpdir) - pdict2 = WriteOncePersistentDict("pytools-test", container_dir=tmpdir) + pdict1: WriteOncePersistentDict[int, int] = \ + WriteOncePersistentDict("pytools-test", container_dir=tmpdir) + pdict2: WriteOncePersistentDict[int, int] = \ + WriteOncePersistentDict("pytools-test", container_dir=tmpdir) # check lookup pdict1[1] = 0 @@ -339,10 +344,11 @@ def test_write_once_persistent_dict_synchronization(): shutil.rmtree(tmpdir) -def test_write_once_persistent_dict_cache_collisions(): +def test_write_once_persistent_dict_cache_collisions() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = WriteOncePersistentDict("pytools-test", container_dir=tmpdir) + pdict: WriteOncePersistentDict[Any, int] = \ + WriteOncePersistentDict("pytools-test", container_dir=tmpdir) key1 = PDictTestingKeyOrValue(1, hash_key=0) key2 = PDictTestingKeyOrValue(2, hash_key=0) @@ -365,10 +371,11 @@ def test_write_once_persistent_dict_cache_collisions(): shutil.rmtree(tmpdir) -def test_write_once_persistent_dict_clear(): +def test_write_once_persistent_dict_clear() -> None: try: tmpdir = tempfile.mkdtemp() - pdict = WriteOncePersistentDict("pytools-test", container_dir=tmpdir) + pdict: WriteOncePersistentDict[int, int] = \ + WriteOncePersistentDict("pytools-test", container_dir=tmpdir) pdict[0] = 1 pdict.fetch(0) @@ -380,7 +387,7 @@ def test_write_once_persistent_dict_clear(): shutil.rmtree(tmpdir) -def test_dtype_hashing(): +def test_dtype_hashing() -> None: np = pytest.importorskip("numpy") keyb = KeyBuilder() @@ -388,7 +395,7 @@ def test_dtype_hashing(): assert keyb(np.dtype(np.float32)) == keyb(np.dtype(np.float32)) -def test_scalar_hashing(): +def test_scalar_hashing() -> None: keyb = KeyBuilder() assert keyb(1) == keyb(1) @@ -429,7 +436,7 @@ def test_scalar_hashing(): "constantdict", ("immutables", "Map"), ("pyrsistent", "pmap"))) -def test_dict_hashing(dict_impl): +def test_dict_hashing(dict_impl) -> None: if isinstance(dict_impl, str): dict_package = dict_impl dict_class = dict_impl @@ -450,7 +457,7 @@ def test_dict_hashing(dict_impl): assert keyb(dc(d)) == keyb(dc({"b": 2, "a": 1})) -def test_frozenset_hashing(): +def test_frozenset_hashing() -> None: keyb = KeyBuilder() assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([1, 2, 3])) @@ -458,7 +465,7 @@ def test_frozenset_hashing(): assert keyb(frozenset([1, 2, 3])) == keyb(frozenset([3, 2, 1])) -def test_frozenorderedset_hashing(): +def test_frozenorderedset_hashing() -> None: pytest.importorskip("orderedsets") from orderedsets import FrozenOrderedSet keyb = KeyBuilder() @@ -470,7 +477,7 @@ def test_frozenorderedset_hashing(): assert keyb(FrozenOrderedSet([1, 2, 3])) == keyb(FrozenOrderedSet([3, 2, 1])) -def test_ABC_hashing(): # noqa: N802 +def test_ABC_hashing() -> None: # noqa: N802 from abc import ABC, ABCMeta keyb = KeyBuilder() @@ -500,7 +507,7 @@ def test_ABC_hashing(): # noqa: N802 assert keyb(MyABC3) != keyb(MyABC) != keyb(MyABC3()) -def test_class_hashing(): +def test_class_hashing() -> None: keyb = KeyBuilder() class WithUpdateMethod: @@ -531,11 +538,11 @@ def test_class_hashing(): class TagClass3(Tag): s: str - assert keyb(TagClass3("foo")) == \ - "c6521f4157ed530d04e956b7046db85e038c120b047cd1b848340d81f9fd8b4a" + assert (keyb(TagClass3("foo")) # type: ignore[call-arg] + == "c6521f4157ed530d04e956b7046db85e038c120b047cd1b848340d81f9fd8b4a") -def test_dataclass_hashing(): +def test_dataclass_hashing() -> None: keyb = KeyBuilder() @dataclass @@ -558,7 +565,7 @@ def test_dataclass_hashing(): assert keyb(MyDC2("hi", 1)) != keyb(MyDC("hi", 1)) -def test_attrs_hashing(): +def test_attrs_hashing() -> None: attrs = pytest.importorskip("attrs") keyb = KeyBuilder() @@ -568,18 +575,18 @@ def test_attrs_hashing(): name: str value: int - assert keyb(MyAttrs("hi", 1)) == \ - "17f272d114d22c1dc0117354777f2d506b303d90e10840d39fb0eef007252f68" + assert (keyb(MyAttrs("hi", 1)) # type: ignore[call-arg] + == "17f272d114d22c1dc0117354777f2d506b303d90e10840d39fb0eef007252f68") - assert keyb(MyAttrs("hi", 1)) == keyb(MyAttrs("hi", 1)) - assert keyb(MyAttrs("hi", 1)) != keyb(MyAttrs("hi", 2)) + assert keyb(MyAttrs("hi", 1)) == keyb(MyAttrs("hi", 1)) # type: ignore[call-arg] + assert keyb(MyAttrs("hi", 1)) != keyb(MyAttrs("hi", 2)) # type: ignore[call-arg] @dataclass class MyDC: name: str value: int - assert keyb(MyDC("hi", 1)) != keyb(MyAttrs("hi", 1)) + assert keyb(MyDC("hi", 1)) != keyb(MyAttrs("hi", 1)) # type: ignore[call-arg] @attrs.define class MyAttrs2: @@ -587,10 +594,11 @@ def test_attrs_hashing(): value: int # Class types must be encoded in hash - assert keyb(MyAttrs2("hi", 1)) != keyb(MyAttrs("hi", 1)) + assert (keyb(MyAttrs2("hi", 1)) # type: ignore[call-arg] + != keyb(MyAttrs("hi", 1))) # type: ignore[call-arg] -def test_xdg_cache_home(): +def test_xdg_cache_home() -> None: import os xdg_dir = "tmpdir_pytools_xdg_test" diff --git a/run-mypy.sh b/run-mypy.sh index 39055a8..715f971 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -4,4 +4,4 @@ set -ex mypy --show-error-codes pytools -mypy --strict --follow-imports=skip pytools/datatable.py +mypy --strict --follow-imports=skip pytools/datatable.py pytools/persistent_dict.py -- GitLab