From 90b3c2c184738e7723b70a96c583ec7120a19222 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Sun, 27 Sep 2009 00:08:29 -0400
Subject: [PATCH] LogManager: Don't keep data in memory. (d'oh)

---
 pytools/log.py | 189 ++++++++++++++++++++++---------------------------
 1 file changed, 83 insertions(+), 106 deletions(-)

diff --git a/pytools/log.py b/pytools/log.py
index dc85e2b..a02b926 100644
--- a/pytools/log.py
+++ b/pytools/log.py
@@ -247,7 +247,7 @@ class LogManager(object):
         assert mode in ["w", "r"], "invalid mode"
 
         self.quantity_data = {}
-        self.quantity_table = {}
+        self.last_values = {}
         self.gather_descriptors = []
         self.tick_count = 0
 
@@ -276,53 +276,52 @@ class LogManager(object):
         self.have_nonlocal_watches = False
 
         # database binding
-        if filename is not None:
+        try:
+            import sqlite3 as sqlite
+        except ImportError:
             try:
-                import sqlite3 as sqlite
+                from pysqlite2 import dbapi2 as sqlite
             except ImportError:
-                try:
-                    from pysqlite2 import dbapi2 as sqlite
-                except ImportError:
-                    raise ImportError, "could not find a usable version of sqlite."
+                raise ImportError, "could not find a usable version of sqlite."
 
+        if filename is None:
+            filename = ":memory:"
+        else:
             if self.is_parallel:
                 filename += "-rank%d" % self.rank
 
-            self.db_conn = sqlite.connect(filename, timeout=30)
-            self.mode = mode
-            try:
-                self.db_conn.execute("select * from quantities;")
-            except sqlite.OperationalError:
-                # we're building a new database
-                if mode == "r":
-                    raise RuntimeError, "Log database '%s' not found" % filename
-
-                self.schema_version = _set_up_schema(self.db_conn)
-                self.set_constant("schema_version", self.schema_version)
-
-                self.set_constant("is_parallel", self.is_parallel)
-
-                # set globally unique run_id
-                if self.is_parallel:
-                    from boostmpi import broadcast
-                    self.set_constant("unique_run_id",
-                            broadcast(self.mpi_comm, _get_unique_id(),
-                                root=self.head_rank))
-                else:
-                    self.set_constant("unique_run_id", _get_unique_id())
+        self.db_conn = sqlite.connect(filename, timeout=30)
+        self.mode = mode
+        try:
+            self.db_conn.execute("select * from quantities;")
+        except sqlite.OperationalError:
+            # we're building a new database
+            if mode == "r":
+                raise RuntimeError, "Log database '%s' not found" % filename
 
-                if self.is_parallel:
-                    self.set_constant("rank_count", self.mpi_comm.size)
-                else:
-                    self.set_constant("rank_count", 1)
+            self.schema_version = _set_up_schema(self.db_conn)
+            self.set_constant("schema_version", self.schema_version)
+
+            self.set_constant("is_parallel", self.is_parallel)
+
+            # set globally unique run_id
+            if self.is_parallel:
+                from boostmpi import broadcast
+                self.set_constant("unique_run_id",
+                        broadcast(self.mpi_comm, _get_unique_id(),
+                            root=self.head_rank))
             else:
-                # we've opened an existing database
-                if mode == "w":
-                    raise RuntimeError, "Log database '%s' already exists" % filename
-                self._load()
-        else:
-            self.db_conn = None
+                self.set_constant("unique_run_id", _get_unique_id())
 
+            if self.is_parallel:
+                self.set_constant("rank_count", self.mpi_comm.size)
+            else:
+                self.set_constant("rank_count", 1)
+        else:
+            # we've opened an existing database
+            if mode == "w":
+                raise RuntimeError, "Log database '%s' already exists" % filename
+            self._load()
 
         self.old_showwarning = None
         if capture_warnings:
@@ -336,9 +335,7 @@ class LogManager(object):
                 # cater to Python 2.5 and earlier
                 self.old_showwarning(message, category, filename, lineno)
 
-            if (self.db_conn is not None
-                    and self.schema_version >= 1
-                    and self.mode == "w"):
+            if self.schema_version >= 1 and self.mode == "w":
                 if self.schema_version >= 2:
                     self.db_conn.execute("insert into warnings values (?,?,?,?,?,?)",
                             (self.rank, self.tick_count, str(message), str(category),
@@ -385,40 +382,33 @@ class LogManager(object):
             self.capture_warnings(False)
 
         self.save()
-
-        if self.db_conn is not None:
-            self.db_conn.close()
+        self.db_conn.close()
 
 
     def get_table(self, q_name):
         if q_name not in self.quantity_data:
             raise KeyError, "invalid quantity name '%s'" % q_name
 
-        try:
-            return self.quantity_table[q_name]
-        except KeyError:
-            from pytools.datatable import DataTable
-            result = self.quantity_table[q_name] = DataTable(["step", "rank", "value"])
+        from pytools.datatable import DataTable
+        result = DataTable(["step", "rank", "value"])
 
-            if self.db_conn is not None:
-                for row in self.db_conn.execute(
-                        "select step, rank, value from %s" % q_name):
-                    result.insert_row(row)
+        for row in self.db_conn.execute(
+                "select step, rank, value from %s" % q_name):
+            result.insert_row(row)
 
-            return result
+        return result
 
     def get_warnings(self):
         columns = ["step", "message", "category", "filename", "lineno"]
-        if self.db_conn is not None and self.schema_version >= 2:
+        if self.schema_version >= 2:
             columns.insert(0, "rank")
 
         from pytools.datatable import DataTable
         result = DataTable(columns)
 
-        if self.db_conn is not None:
-            for row in self.db_conn.execute(
-                    "select %s from warnings" % (", ".join(columns))):
-                result.insert_row(row)
+        for row in self.db_conn.execute(
+                "select %s from warnings" % (", ".join(columns))):
+            result.insert_row(row)
 
         return result
 
@@ -455,18 +445,17 @@ class LogManager(object):
         existed = name in self.constants
         self.constants[name] = value
 
-        if self.db_conn is not None:
-            from pickle import dumps
-            value = buffer(dumps(value))
+        from pickle import dumps
+        value = buffer(dumps(value))
 
-            if existed:
-                self.db_conn.execute("update constants set value = ? where name = ?",
-                        (value, name))
-            else:
-                self.db_conn.execute("insert into constants values (?,?)",
-                        (name, value))
+        if existed:
+            self.db_conn.execute("update constants set value = ? where name = ?",
+                    (value, name))
+        else:
+            self.db_conn.execute("insert into constants values (?,?)",
+                    (name, value))
 
-            self.db_conn.commit()
+        self.db_conn.commit()
 
     def tick(self):
         """Record data points from each added L{LogQuantity}.
@@ -480,15 +469,14 @@ class LogManager(object):
             if value is None:
                 return
 
-            self.get_table(name).insert_row(
-                (self.tick_count, self.rank, value))
-            if self.db_conn is not None:
-                try:
-                    self.db_conn.execute("insert into %s values (?,?,?)" % name,
-                            (self.tick_count, self.rank, float(value)))
-                except:
-                    print "while adding datapoint for '%s':" % name
-                    raise
+            self.last_values[name] = value
+
+            try:
+                self.db_conn.execute("insert into %s values (?,?,?)" % name,
+                        (self.tick_count, self.rank, float(value)))
+            except:
+                print "while adding datapoint for '%s':" % name
+                raise
 
         for gd in self.gather_descriptors:
             if self.tick_count % gd.interval == 0:
@@ -517,13 +505,12 @@ class LogManager(object):
         self.t_log = time() - tick_start_time
 
     def save(self):
-        if self.db_conn is not None:
-            from sqlite3 import OperationalError
-            try:
-                self.db_conn.commit()
-            except OperationalError, e:
-                from warnings import warn
-                warn("encountered sqlite error during commit: %s" % e)
+        from sqlite3 import OperationalError
+        try:
+            self.db_conn.commit()
+        except OperationalError, e:
+            from warnings import warn
+            warn("encountered sqlite error during commit: %s" % e)
 
         self.last_save_time = time()
 
@@ -532,18 +519,16 @@ class LogManager(object):
         def add_internal(name, unit, description, def_agg):
             if name in self.quantity_data:
                 raise RuntimeError("cannot add the same quantity '%s' twice" % name)
-            self.quantity_data[name] = _QuantityData(
-                    unit, description, def_agg)
+            self.quantity_data[name] = _QuantityData(unit, description, def_agg)
 
-            if self.db_conn is not None:
-                from pickle import dumps
-                self.db_conn.execute("""insert into quantities values (?,?,?,?)""", (
-                      name, unit, description,
-                      buffer(dumps(def_agg))))
-                self.db_conn.execute("""create table %s
-                  (step integer, rank integer, value real)""" % name)
+            from pickle import dumps
+            self.db_conn.execute("""insert into quantities values (?,?,?,?)""", (
+                  name, unit, description,
+                  buffer(dumps(def_agg))))
+            self.db_conn.execute("""create table %s
+              (step integer, rank integer, value real)""" % name)
 
-                self.db_conn.commit()
+            self.db_conn.commit()
 
         self.gather_descriptors.append(_GatherDescriptor(quantity, interval))
 
@@ -791,13 +776,7 @@ class LogManager(object):
         if not self.have_nonlocal_watches and self.rank != self.head_rank:
             return
 
-        def get_last_value(table):
-            if table:
-                return table.data[-1][2]
-            else:
-                return 0
-
-        data_block = dict((qname, get_last_value(self.get_table(qname)))
+        data_block = dict((qname, self.last_values.get(qname, 0))
                 for qname in self.quantity_data.iterkeys())
 
         if self.mpi_comm is not None and self.have_nonlocal_watches:
@@ -823,9 +802,7 @@ class LogManager(object):
 
             if self.watches:
                 print " | ".join(
-                        compute_watch_str(watch)
-                        for watch in self.watches
-                        )
+                        compute_watch_str(watch) for watch in self.watches)
 
         ticks_per_sec = self.tick_count/max(1, time()-self.start_time)
         self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec))
-- 
GitLab