From b1be1516663213447627ed49553afed175e4ab8b Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Thu, 5 Oct 2017 16:15:58 -0500
Subject: [PATCH] Improve PersistentDict warnings. * Add a CacheCollision
 warning subclass. * Track stack level so that warnings show user code line
 numbers.

Example:

Modifying line 287 of test_persistent_dict.py so that the warning
isn't captured by the test, we get this:

```
test_persistent_dict.py:287: CollisionWarning: pytools-test: key collision in cache at '/tmp/tmpq9p3i0b9' -- these are sufficiently unlikely that they're often indicative of a broken implementation of equality comparison
  pdict[key2]  # user code
```

The old behavior was:

```
/home/matt/src/pytools/pytools/persistent_dict.py:466: UserWarning: pytools-test: key collision in cache at '/tmp/tmpuk8js9jw' -- these are sufficiently unlikely that they're often indicative of a broken implementation of equality comparison
  % (self.identifier, self.container_dir))
```
---
 pytools/persistent_dict.py   | 85 ++++++++++++++++++++++++++++--------
 test/test_persistent_dict.py |  8 ++--
 2 files changed, 70 insertions(+), 23 deletions(-)

diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index 02b14dc..b0dc9fd 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -32,6 +32,7 @@ logger = logging.getLogger(__name__)
 
 
 import collections
+import functools
 import six
 import sys
 import os
@@ -48,6 +49,8 @@ valid across interpreter invocations, unlike Python's built-in hashes.
 .. autoexception:: NoSuchEntryError
 .. autoexception:: ReadOnlyEntryError
 
+.. autowarning:: CollisionWarning
+
 .. autoclass:: KeyBuilder
 .. autoclass:: PersistentDict
 .. autoclass:: WriteOncePersistentDict
@@ -78,6 +81,37 @@ def update_checksum(checksum, obj):
         checksum.update(obj)
 
 
+def _tracks_stacklevel(cls, exclude=frozenset(["__init__"])):
+    """Changes all the methods of `cls` to track the call stack level in a member
+    called `_stacklevel`.
+    """
+    def make_wrapper(f):
+        @functools.wraps(f)
+        def wrapper(obj, *args, **kwargs):
+            assert obj._stacklevel >= 0, obj._stacklevel
+            # Increment by 2 because the method is wrapped.
+            obj._stacklevel += 2
+            try:
+                return f(obj, *args, **kwargs)
+            finally:
+                obj._stacklevel -= 2
+
+        return wrapper
+
+    for member in cls.__dict__:
+        f = getattr(cls, member)
+
+        if member in exclude:
+            continue
+
+        if not six.callable(f):
+            continue
+
+        setattr(cls, member, make_wrapper(f))
+
+    return cls
+
+
 # {{{ cleanup managers
 
 class CleanupBase(object):
@@ -101,7 +135,7 @@ class CleanupManager(CleanupBase):
 
 
 class LockManager(CleanupBase):
-    def __init__(self, cleanup_m, lock_file):
+    def __init__(self, cleanup_m, lock_file, _stacklevel=1):
         self.lock_file = lock_file
 
         attempts = 0
@@ -121,7 +155,8 @@ class LockManager(CleanupBase):
             if attempts > 10:
                 from warnings import warn
                 warn("could not obtain lock--delete '%s' if necessary"
-                        % self.lock_file)
+                        % self.lock_file,
+                     stacklevel=1 + _stacklevel)
             if attempts > 3 * 60:
                 raise RuntimeError("waited more than three minutes "
                         "on the lock file '%s'"
@@ -395,8 +430,16 @@ class ReadOnlyEntryError(KeyError):
     pass
 
 
+class CollisionWarning(UserWarning):
+    pass
+
+
+@_tracks_stacklevel
 class _PersistentDictBase(object):
     def __init__(self, identifier, key_builder=None, container_dir=None):
+        # for issuing warnings
+        self._stacklevel = 0
+
         self.identifier = identifier
 
         if key_builder is None:
@@ -417,6 +460,10 @@ class _PersistentDictBase(object):
 
         self._make_container_dir()
 
+    def _warn(self, msg, category=UserWarning):
+        from warnings import warn
+        warn(msg, category, stacklevel=1 + self._stacklevel)
+
     def store_if_not_present(self, key, value):
         self.store(key, value, _skip_if_present=True)
 
@@ -458,12 +505,12 @@ class _PersistentDictBase(object):
     def _collision_check(self, key, stored_key):
         if stored_key != key:
             # Key collision, oh well.
-            from warnings import warn
-            warn("%s: key collision in cache at '%s' -- these are "
+            self._warn("%s: key collision in cache at '%s' -- these are "
                     "sufficiently unlikely that they're often "
                     "indicative of a broken implementation "
                     "of equality comparison"
-                    % (self.identifier, self.container_dir))
+                    % (self.identifier, self.container_dir),
+                 CollisionWarning)
             # This is here so we can debug the equality comparison
             stored_key == key
             raise NoSuchEntryError(key)
@@ -487,6 +534,7 @@ class _PersistentDictBase(object):
         self._make_container_dir()
 
 
+@_tracks_stacklevel
 class WriteOncePersistentDict(_PersistentDictBase):
     def __init__(self, identifier, key_builder=None, container_dir=None,
              in_mem_cache_size=256):
@@ -516,8 +564,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
             attempts += 1
 
             if attempts > 10:
-                from warnings import warn
-                warn("waiting until unlocked--delete '%s' if necessary"
+                self._warn("waiting until unlocked--delete '%s' if necessary"
                         % lock_file)
 
             if attempts > 3 * 60:
@@ -600,8 +647,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
         try:
             read_key = self._read(key_file)
         except:
-            from warnings import warn
-            warn("pytools.persistent_dict.WriteOncePersistentDict(%s) "
+            self._warn("pytools.persistent_dict.WriteOncePersistentDict(%s) "
                     "encountered an invalid "
                     "key file for key %s. Remove the directory "
                     "'%s' if necessary."
@@ -620,7 +666,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
         try:
             read_contents = self._read(contents_file)
         except:
-            warn("pytools.persistent_dict.WriteOncePersistentDict(%s) "
+            self._warn("pytools.persistent_dict.WriteOncePersistentDict(%s) "
                     "encountered an invalid "
                     "key file for key %s. Remove the directory "
                     "'%s' if necessary."
@@ -637,6 +683,7 @@ class WriteOncePersistentDict(_PersistentDictBase):
         self._cache.clear()
 
 
+@_tracks_stacklevel
 class PersistentDict(_PersistentDictBase):
     def __init__(self, identifier, key_builder=None, container_dir=None):
         """
@@ -658,7 +705,8 @@ class PersistentDict(_PersistentDictBase):
         cleanup_m = CleanupManager()
         try:
             try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key))
+                LockManager(cleanup_m, self._lock_file(hexdigest_key),
+                        1 + self._stacklevel)
                 item_dir_m = ItemDirManager(
                         cleanup_m, self._item_dir(hexdigest_key),
                         delete_on_error=True)
@@ -697,7 +745,8 @@ class PersistentDict(_PersistentDictBase):
         cleanup_m = CleanupManager()
         try:
             try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key))
+                LockManager(cleanup_m, self._lock_file(hexdigest_key),
+                        1 + self._stacklevel)
                 item_dir_m = ItemDirManager(
                         cleanup_m, item_dir, delete_on_error=False)
 
@@ -710,8 +759,7 @@ class PersistentDict(_PersistentDictBase):
                     read_key = self._read(key_path)
                 except:
                     item_dir_m.reset()
-                    from warnings import warn
-                    warn("pytools.persistent_dict.PersistentDict(%s) "
+                    self._warn("pytools.persistent_dict.PersistentDict(%s) "
                             "encountered an invalid "
                             "key file for key %s. Entry deleted."
                             % (self.identifier, hexdigest_key))
@@ -730,8 +778,7 @@ class PersistentDict(_PersistentDictBase):
                     read_contents = self._read(value_path)
                 except:
                     item_dir_m.reset()
-                    from warnings import warn
-                    warn("pytools.persistent_dict.PersistentDict(%s) "
+                    self._warn("pytools.persistent_dict.PersistentDict(%s) "
                             "encountered an invalid "
                             "key file for key %s. Entry deleted."
                             % (self.identifier, hexdigest_key))
@@ -758,7 +805,8 @@ class PersistentDict(_PersistentDictBase):
         cleanup_m = CleanupManager()
         try:
             try:
-                LockManager(cleanup_m, self._lock_file(hexdigest_key))
+                LockManager(cleanup_m, self._lock_file(hexdigest_key),
+                        1 + self._stacklevel)
                 item_dir_m = ItemDirManager(
                         cleanup_m, item_dir, delete_on_error=False)
                 key_file = self._key_file(hexdigest_key)
@@ -769,8 +817,7 @@ class PersistentDict(_PersistentDictBase):
                     read_key = self._read(key_file)
                 except:
                     item_dir_m.reset()
-                    from warnings import warn
-                    warn("pytools.persistent_dict.PersistentDict(%s) "
+                    self._warn("pytools.persistent_dict.PersistentDict(%s) "
                             "encountered an invalid "
                             "key file for key %s. Entry deleted."
                             % (self.identifier, hexdigest_key))
diff --git a/test/test_persistent_dict.py b/test/test_persistent_dict.py
index 4256a24..ea6665f 100644
--- a/test/test_persistent_dict.py
+++ b/test/test_persistent_dict.py
@@ -10,7 +10,7 @@ from six.moves import zip
 
 from pytools.persistent_dict import (
         PersistentDict, WriteOncePersistentDict, NoSuchEntryError,
-        ReadOnlyEntryError)
+        ReadOnlyEntryError, CollisionWarning)
 
 
 # {{{ type for testing
@@ -155,12 +155,12 @@ def test_persistent_dict_cache_collisions():
         pdict[key1] = 1
 
         # check lookup
-        with pytest.warns(UserWarning):
+        with pytest.warns(CollisionWarning):
             with pytest.raises(NoSuchEntryError):
                 pdict[key2]
 
         # check deletion
-        with pytest.warns(UserWarning):
+        with pytest.warns(CollisionWarning):
             with pytest.raises(NoSuchEntryError):
                 del pdict[key2]
 
@@ -283,7 +283,7 @@ def test_write_once_persistent_dict_cache_collisions():
         pdict[key1] = 1
 
         # check lookup
-        with pytest.warns(UserWarning):
+        with pytest.warns(CollisionWarning):
             with pytest.raises(NoSuchEntryError):
                 pdict[key2]
 
-- 
GitLab