From 7909696ea52ac74f204867bf2a13780e76a8ade7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Thu, 21 Apr 2016 00:22:35 -0500 Subject: [PATCH] Fix, test DiskDict --- pytools/diskdict.py | 15 ++++++++------- test/test_pytools.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/pytools/diskdict.py b/pytools/diskdict.py index d130389..a01ff05 100644 --- a/pytools/diskdict.py +++ b/pytools/diskdict.py @@ -90,9 +90,9 @@ class DiskDict(object): "select key_pickle, version_pickle, result_pickle from data" " where key_hash = ? and version_hash = ?", (hash(key), self.version_hash)): - if loads(str(key_pickle)) == key \ - and loads(str(version_pickle)) == self.version: - result = loads(str(result_pickle)) + if loads(six.binary_type(key_pickle)) == key \ + and loads(six.binary_type(version_pickle)) == self.version: + result = loads(six.binary_type(result_pickle)) self.cache[key] = result return True @@ -107,9 +107,9 @@ class DiskDict(object): "select key_pickle, version_pickle, result_pickle from data" " where key_hash = ? and version_hash = ?", (hash(key), self.version_hash)): - if loads(str(key_pickle)) == key \ - and loads(str(version_pickle)) == self.version: - result = loads(str(result_pickle)) + if loads(six.binary_type(key_pickle)) == key \ + and loads(six.binary_type(version_pickle)) == self.version: + result = loads(six.binary_type(result_pickle)) self.cache[key] = result return result @@ -124,7 +124,8 @@ class DiskDict(object): "select id, key_pickle, version_pickle from data" " where key_hash = ? and version_hash = ?", (hash(key), self.version_hash)): - if loads(key_pickle) == key and loads(version_pickle) == self.version: + if (loads(six.binary_type(key_pickle)) == key + and loads(six.binary_type(version_pickle)) == self.version): self.db_conn.execute("delete from data where id = ?", (item_id,)) self.commit_countdown -= 1 diff --git a/test/test_pytools.py b/test/test_pytools.py index 3773243..6e2a1eb 100644 --- a/test/test_pytools.py +++ b/test/test_pytools.py @@ -154,6 +154,42 @@ def test_spatial_btree(dims, do_plot=False): pt.show() +def test_diskdict(): + from pytools.diskdict import DiskDict + + from tempfile import NamedTemporaryFile + + with NamedTemporaryFile() as ntf: + d = DiskDict(ntf.name) + + key_val = [ + ((), "hi"), + (frozenset([1, 2, "hi"]), 5) + ] + + for k, v in key_val: + d[k] = v + for k, v in key_val: + assert d[k] == v + del d + + d = DiskDict(ntf.name) + for k, v in key_val: + del d[k] + del d + + d = DiskDict(ntf.name) + for k, v in key_val: + d[k] = v + del d + + d = DiskDict(ntf.name) + for k, v in key_val: + assert k in d + assert d[k] == v + del d + + if __name__ == "__main__": # make sure that import failures get reported, instead of skipping the tests. import pyopencl # noqa -- GitLab