From 90b3c2c184738e7723b70a96c583ec7120a19222 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Sun, 27 Sep 2009 00:08:29 -0400 Subject: [PATCH] LogManager: Don't keep data in memory. (d'oh) --- pytools/log.py | 189 ++++++++++++++++++++++--------------------------- 1 file changed, 83 insertions(+), 106 deletions(-) diff --git a/pytools/log.py b/pytools/log.py index dc85e2b..a02b926 100644 --- a/pytools/log.py +++ b/pytools/log.py @@ -247,7 +247,7 @@ class LogManager(object): assert mode in ["w", "r"], "invalid mode" self.quantity_data = {} - self.quantity_table = {} + self.last_values = {} self.gather_descriptors = [] self.tick_count = 0 @@ -276,53 +276,52 @@ class LogManager(object): self.have_nonlocal_watches = False # database binding - if filename is not None: + try: + import sqlite3 as sqlite + except ImportError: try: - import sqlite3 as sqlite + from pysqlite2 import dbapi2 as sqlite except ImportError: - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError: - raise ImportError, "could not find a usable version of sqlite." + raise ImportError, "could not find a usable version of sqlite." + if filename is None: + filename = ":memory:" + else: if self.is_parallel: filename += "-rank%d" % self.rank - self.db_conn = sqlite.connect(filename, timeout=30) - self.mode = mode - try: - self.db_conn.execute("select * from quantities;") - except sqlite.OperationalError: - # we're building a new database - if mode == "r": - raise RuntimeError, "Log database '%s' not found" % filename - - self.schema_version = _set_up_schema(self.db_conn) - self.set_constant("schema_version", self.schema_version) - - self.set_constant("is_parallel", self.is_parallel) - - # set globally unique run_id - if self.is_parallel: - from boostmpi import broadcast - self.set_constant("unique_run_id", - broadcast(self.mpi_comm, _get_unique_id(), - root=self.head_rank)) - else: - self.set_constant("unique_run_id", _get_unique_id()) + self.db_conn = sqlite.connect(filename, timeout=30) + self.mode = mode + try: + self.db_conn.execute("select * from quantities;") + except sqlite.OperationalError: + # we're building a new database + if mode == "r": + raise RuntimeError, "Log database '%s' not found" % filename - if self.is_parallel: - self.set_constant("rank_count", self.mpi_comm.size) - else: - self.set_constant("rank_count", 1) + self.schema_version = _set_up_schema(self.db_conn) + self.set_constant("schema_version", self.schema_version) + + self.set_constant("is_parallel", self.is_parallel) + + # set globally unique run_id + if self.is_parallel: + from boostmpi import broadcast + self.set_constant("unique_run_id", + broadcast(self.mpi_comm, _get_unique_id(), + root=self.head_rank)) else: - # we've opened an existing database - if mode == "w": - raise RuntimeError, "Log database '%s' already exists" % filename - self._load() - else: - self.db_conn = None + self.set_constant("unique_run_id", _get_unique_id()) + if self.is_parallel: + self.set_constant("rank_count", self.mpi_comm.size) + else: + self.set_constant("rank_count", 1) + else: + # we've opened an existing database + if mode == "w": + raise RuntimeError, "Log database '%s' already exists" % filename + self._load() self.old_showwarning = None if capture_warnings: @@ -336,9 +335,7 @@ class LogManager(object): # cater to Python 2.5 and earlier self.old_showwarning(message, category, filename, lineno) - if (self.db_conn is not None - and self.schema_version >= 1 - and self.mode == "w"): + if self.schema_version >= 1 and self.mode == "w": if self.schema_version >= 2: self.db_conn.execute("insert into warnings values (?,?,?,?,?,?)", (self.rank, self.tick_count, str(message), str(category), @@ -385,40 +382,33 @@ class LogManager(object): self.capture_warnings(False) self.save() - - if self.db_conn is not None: - self.db_conn.close() + self.db_conn.close() def get_table(self, q_name): if q_name not in self.quantity_data: raise KeyError, "invalid quantity name '%s'" % q_name - try: - return self.quantity_table[q_name] - except KeyError: - from pytools.datatable import DataTable - result = self.quantity_table[q_name] = DataTable(["step", "rank", "value"]) + from pytools.datatable import DataTable + result = DataTable(["step", "rank", "value"]) - if self.db_conn is not None: - for row in self.db_conn.execute( - "select step, rank, value from %s" % q_name): - result.insert_row(row) + for row in self.db_conn.execute( + "select step, rank, value from %s" % q_name): + result.insert_row(row) - return result + return result def get_warnings(self): columns = ["step", "message", "category", "filename", "lineno"] - if self.db_conn is not None and self.schema_version >= 2: + if self.schema_version >= 2: columns.insert(0, "rank") from pytools.datatable import DataTable result = DataTable(columns) - if self.db_conn is not None: - for row in self.db_conn.execute( - "select %s from warnings" % (", ".join(columns))): - result.insert_row(row) + for row in self.db_conn.execute( + "select %s from warnings" % (", ".join(columns))): + result.insert_row(row) return result @@ -455,18 +445,17 @@ class LogManager(object): existed = name in self.constants self.constants[name] = value - if self.db_conn is not None: - from pickle import dumps - value = buffer(dumps(value)) + from pickle import dumps + value = buffer(dumps(value)) - if existed: - self.db_conn.execute("update constants set value = ? where name = ?", - (value, name)) - else: - self.db_conn.execute("insert into constants values (?,?)", - (name, value)) + if existed: + self.db_conn.execute("update constants set value = ? where name = ?", + (value, name)) + else: + self.db_conn.execute("insert into constants values (?,?)", + (name, value)) - self.db_conn.commit() + self.db_conn.commit() def tick(self): """Record data points from each added L{LogQuantity}. @@ -480,15 +469,14 @@ class LogManager(object): if value is None: return - self.get_table(name).insert_row( - (self.tick_count, self.rank, value)) - if self.db_conn is not None: - try: - self.db_conn.execute("insert into %s values (?,?,?)" % name, - (self.tick_count, self.rank, float(value))) - except: - print "while adding datapoint for '%s':" % name - raise + self.last_values[name] = value + + try: + self.db_conn.execute("insert into %s values (?,?,?)" % name, + (self.tick_count, self.rank, float(value))) + except: + print "while adding datapoint for '%s':" % name + raise for gd in self.gather_descriptors: if self.tick_count % gd.interval == 0: @@ -517,13 +505,12 @@ class LogManager(object): self.t_log = time() - tick_start_time def save(self): - if self.db_conn is not None: - from sqlite3 import OperationalError - try: - self.db_conn.commit() - except OperationalError, e: - from warnings import warn - warn("encountered sqlite error during commit: %s" % e) + from sqlite3 import OperationalError + try: + self.db_conn.commit() + except OperationalError, e: + from warnings import warn + warn("encountered sqlite error during commit: %s" % e) self.last_save_time = time() @@ -532,18 +519,16 @@ class LogManager(object): def add_internal(name, unit, description, def_agg): if name in self.quantity_data: raise RuntimeError("cannot add the same quantity '%s' twice" % name) - self.quantity_data[name] = _QuantityData( - unit, description, def_agg) + self.quantity_data[name] = _QuantityData(unit, description, def_agg) - if self.db_conn is not None: - from pickle import dumps - self.db_conn.execute("""insert into quantities values (?,?,?,?)""", ( - name, unit, description, - buffer(dumps(def_agg)))) - self.db_conn.execute("""create table %s - (step integer, rank integer, value real)""" % name) + from pickle import dumps + self.db_conn.execute("""insert into quantities values (?,?,?,?)""", ( + name, unit, description, + buffer(dumps(def_agg)))) + self.db_conn.execute("""create table %s + (step integer, rank integer, value real)""" % name) - self.db_conn.commit() + self.db_conn.commit() self.gather_descriptors.append(_GatherDescriptor(quantity, interval)) @@ -791,13 +776,7 @@ class LogManager(object): if not self.have_nonlocal_watches and self.rank != self.head_rank: return - def get_last_value(table): - if table: - return table.data[-1][2] - else: - return 0 - - data_block = dict((qname, get_last_value(self.get_table(qname))) + data_block = dict((qname, self.last_values.get(qname, 0)) for qname in self.quantity_data.iterkeys()) if self.mpi_comm is not None and self.have_nonlocal_watches: @@ -823,9 +802,7 @@ class LogManager(object): if self.watches: print " | ".join( - compute_watch_str(watch) - for watch in self.watches - ) + compute_watch_str(watch) for watch in self.watches) ticks_per_sec = self.tick_count/max(1, time()-self.start_time) self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec)) -- GitLab