From 7909696ea52ac74f204867bf2a13780e76a8ade7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 21 Apr 2016 00:22:35 -0500
Subject: [PATCH] Fix, test DiskDict

---
 pytools/diskdict.py  | 15 ++++++++-------
 test/test_pytools.py | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/pytools/diskdict.py b/pytools/diskdict.py
index d130389..a01ff05 100644
--- a/pytools/diskdict.py
+++ b/pytools/diskdict.py
@@ -90,9 +90,9 @@ class DiskDict(object):
                     "select key_pickle, version_pickle, result_pickle from data"
                     " where key_hash = ? and version_hash = ?",
                     (hash(key), self.version_hash)):
-                if loads(str(key_pickle)) == key \
-                        and loads(str(version_pickle)) == self.version:
-                    result = loads(str(result_pickle))
+                if loads(six.binary_type(key_pickle)) == key \
+                        and loads(six.binary_type(version_pickle)) == self.version:
+                    result = loads(six.binary_type(result_pickle))
                     self.cache[key] = result
                     return True
 
@@ -107,9 +107,9 @@ class DiskDict(object):
                     "select key_pickle, version_pickle, result_pickle from data"
                     " where key_hash = ? and version_hash = ?",
                     (hash(key), self.version_hash)):
-                if loads(str(key_pickle)) == key \
-                        and loads(str(version_pickle)) == self.version:
-                    result = loads(str(result_pickle))
+                if loads(six.binary_type(key_pickle)) == key \
+                        and loads(six.binary_type(version_pickle)) == self.version:
+                    result = loads(six.binary_type(result_pickle))
                     self.cache[key] = result
                     return result
 
@@ -124,7 +124,8 @@ class DiskDict(object):
                 "select id, key_pickle, version_pickle from data"
                 " where key_hash = ? and version_hash = ?",
                 (hash(key), self.version_hash)):
-            if loads(key_pickle) == key and loads(version_pickle) == self.version:
+            if (loads(six.binary_type(key_pickle)) == key
+                    and loads(six.binary_type(version_pickle)) == self.version):
                 self.db_conn.execute("delete from data where id = ?", (item_id,))
 
         self.commit_countdown -= 1
diff --git a/test/test_pytools.py b/test/test_pytools.py
index 3773243..6e2a1eb 100644
--- a/test/test_pytools.py
+++ b/test/test_pytools.py
@@ -154,6 +154,42 @@ def test_spatial_btree(dims, do_plot=False):
         pt.show()
 
 
+def test_diskdict():
+    from pytools.diskdict import DiskDict
+
+    from tempfile import NamedTemporaryFile
+
+    with NamedTemporaryFile() as ntf:
+        d = DiskDict(ntf.name)
+
+        key_val = [
+            ((), "hi"),
+            (frozenset([1, 2, "hi"]), 5)
+            ]
+
+        for k, v in key_val:
+            d[k] = v
+        for k, v in key_val:
+            assert d[k] == v
+        del d
+
+        d = DiskDict(ntf.name)
+        for k, v in key_val:
+            del d[k]
+        del d
+
+        d = DiskDict(ntf.name)
+        for k, v in key_val:
+            d[k] = v
+        del d
+
+        d = DiskDict(ntf.name)
+        for k, v in key_val:
+            assert k in d
+            assert d[k] == v
+        del d
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
     import pyopencl  # noqa
-- 
GitLab