diff --git a/pytools/diskdict.py b/pytools/diskdict.py index d130389660ea366ec2099bce3be652afaf5b19cb..a01ff05cd45988f4891cf55aafa46ffd4dc74bb5 100644 --- a/pytools/diskdict.py +++ b/pytools/diskdict.py @@ -90,9 +90,9 @@ class DiskDict(object): "select key_pickle, version_pickle, result_pickle from data" " where key_hash = ? and version_hash = ?", (hash(key), self.version_hash)): - if loads(str(key_pickle)) == key \ - and loads(str(version_pickle)) == self.version: - result = loads(str(result_pickle)) + if loads(six.binary_type(key_pickle)) == key \ + and loads(six.binary_type(version_pickle)) == self.version: + result = loads(six.binary_type(result_pickle)) self.cache[key] = result return True @@ -107,9 +107,9 @@ class DiskDict(object): "select key_pickle, version_pickle, result_pickle from data" " where key_hash = ? and version_hash = ?", (hash(key), self.version_hash)): - if loads(str(key_pickle)) == key \ - and loads(str(version_pickle)) == self.version: - result = loads(str(result_pickle)) + if loads(six.binary_type(key_pickle)) == key \ + and loads(six.binary_type(version_pickle)) == self.version: + result = loads(six.binary_type(result_pickle)) self.cache[key] = result return result @@ -124,7 +124,8 @@ class DiskDict(object): "select id, key_pickle, version_pickle from data" " where key_hash = ? and version_hash = ?", (hash(key), self.version_hash)): - if loads(key_pickle) == key and loads(version_pickle) == self.version: + if (loads(six.binary_type(key_pickle)) == key + and loads(six.binary_type(version_pickle)) == self.version): self.db_conn.execute("delete from data where id = ?", (item_id,)) self.commit_countdown -= 1 diff --git a/test/test_pytools.py b/test/test_pytools.py index 377324386c0b3f418b325324db2a9a13a7458019..6e2a1ebf9d4908c5329f0bc84ea3b002a2d813ad 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -154,6 +154,42 @@ def test_spatial_btree(dims, do_plot=False): pt.show() +def test_diskdict(): + from pytools.diskdict import DiskDict + + from tempfile import NamedTemporaryFile + + with NamedTemporaryFile() as ntf: + d = DiskDict(ntf.name) + + key_val = [ + ((), "hi"), + (frozenset([1, 2, "hi"]), 5) + ] + + for k, v in key_val: + d[k] = v + for k, v in key_val: + assert d[k] == v + del d + + d = DiskDict(ntf.name) + for k, v in key_val: + del d[k] + del d + + d = DiskDict(ntf.name) + for k, v in key_val: + d[k] = v + del d + + d = DiskDict(ntf.name) + for k, v in key_val: + assert k in d + assert d[k] == v + del d + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa