diff --git a/pytools/diskdict.py b/pytools/diskdict.py
index d130389660ea366ec2099bce3be652afaf5b19cb..a01ff05cd45988f4891cf55aafa46ffd4dc74bb5 100644
--- a/pytools/diskdict.py
+++ b/pytools/diskdict.py
@@ -90,9 +90,9 @@ class DiskDict(object):
                     "select key_pickle, version_pickle, result_pickle from data"
                     " where key_hash = ? and version_hash = ?",
                     (hash(key), self.version_hash)):
-                if loads(str(key_pickle)) == key \
-                        and loads(str(version_pickle)) == self.version:
-                    result = loads(str(result_pickle))
+                if loads(six.binary_type(key_pickle)) == key \
+                        and loads(six.binary_type(version_pickle)) == self.version:
+                    result = loads(six.binary_type(result_pickle))
                     self.cache[key] = result
                     return True
 
@@ -107,9 +107,9 @@ class DiskDict(object):
                     "select key_pickle, version_pickle, result_pickle from data"
                     " where key_hash = ? and version_hash = ?",
                     (hash(key), self.version_hash)):
-                if loads(str(key_pickle)) == key \
-                        and loads(str(version_pickle)) == self.version:
-                    result = loads(str(result_pickle))
+                if loads(six.binary_type(key_pickle)) == key \
+                        and loads(six.binary_type(version_pickle)) == self.version:
+                    result = loads(six.binary_type(result_pickle))
                     self.cache[key] = result
                     return result
 
@@ -124,7 +124,8 @@ class DiskDict(object):
                 "select id, key_pickle, version_pickle from data"
                 " where key_hash = ? and version_hash = ?",
                 (hash(key), self.version_hash)):
-            if loads(key_pickle) == key and loads(version_pickle) == self.version:
+            if (loads(six.binary_type(key_pickle)) == key
+                    and loads(six.binary_type(version_pickle)) == self.version):
                 self.db_conn.execute("delete from data where id = ?", (item_id,))
 
         self.commit_countdown -= 1
diff --git a/test/test_pytools.py b/test/test_pytools.py
index 377324386c0b3f418b325324db2a9a13a7458019..6e2a1ebf9d4908c5329f0bc84ea3b002a2d813ad 100644
--- a/test/test_pytools.py
+++ b/test/test_pytools.py
@@ -154,6 +154,42 @@ def test_spatial_btree(dims, do_plot=False):
         pt.show()
 
 
+def test_diskdict():
+    from pytools.diskdict import DiskDict
+
+    from tempfile import NamedTemporaryFile
+
+    with NamedTemporaryFile() as ntf:
+        d = DiskDict(ntf.name)
+
+        key_val = [
+            ((), "hi"),
+            (frozenset([1, 2, "hi"]), 5)
+            ]
+
+        for k, v in key_val:
+            d[k] = v
+        for k, v in key_val:
+            assert d[k] == v
+        del d
+
+        d = DiskDict(ntf.name)
+        for k, v in key_val:
+            del d[k]
+        del d
+
+        d = DiskDict(ntf.name)
+        for k, v in key_val:
+            d[k] = v
+        del d
+
+        d = DiskDict(ntf.name)
+        for k, v in key_val:
+            assert k in d
+            assert d[k] == v
+        del d
+
+
 if __name__ == "__main__":
     # make sure that import failures get reported, instead of skipping the tests.
     import pyopencl  # noqa