diff --git a/src/log.py b/src/log.py index 921fe1d7caa3ec33b97cdece1f4138d38e1997fa..88b6fde3222996d5473eb968839e01cf9a11105b 100644 --- a/src/log.py +++ b/src/log.py @@ -100,17 +100,11 @@ class _GatherDescriptor(object): class _QuantityData(object): - def __init__(self, unit, description, default_aggregator, table=None): + def __init__(self, unit, description, default_aggregator): self.unit = unit self.description = description self.default_aggregator = default_aggregator - if table is None: - from pytools.datatable import DataTable - self.table = DataTable(["step", "rank", "value"]) - else: - self.table = table.copy() - @@ -174,6 +168,7 @@ class LogManager(object): synchronized to the head node, which then writes them out to disk. """ self.quantity_data = {} + self.quantity_table = {} self.gather_descriptors = [] self.tick_count = 0 @@ -239,9 +234,23 @@ class LogManager(object): qdat = self.quantity_data[name] = _QuantityData( unit, description, loads(def_agg)) - for row in self.db_conn.execute( - "select step, rank, value from %s" % name): - qdat.table.insert_row(row) + + def get_table(self, q_name): + if q_name not in self.quantity_data: + raise KeyError, "invalid quantity name '%s'" % q_name + + try: + return self.quantity_table[q_name] + except KeyError: + from pytools.datatable import DataTable + result = self.quantity_table[q_name] = DataTable(["step", "rank", "value"]) + + if self.db_conn is not None: + for row in self.db_conn.execute( + "select step, rank, value from %s" % q_name): + result.insert_row(row) + + return result def add_watches(self, watches): """Add quantities that are printed after every time step.""" @@ -287,7 +296,7 @@ class LogManager(object): start_time = time() def insert_datapoint(name, value): - self.quantity_data[name].table.insert_row( + self.get_table(name).insert_row( (self.tick_count, self.rank, value)) if self.db_conn is not None: self.db_conn.execute("insert into %s values (?,?,?)" % name, @@ -340,7 +349,7 @@ class LogManager(object): if self.mpi_comm.rank == self.head_rank: for rank_data in gather(self.mpi_comm, None, self.head_rank)[1:]: for name, rows in rank_data: - self.quantity_data[name].table.insert_rows(rows) + self.get_table(name).insert_rows(rows) if self.db_conn is not None: for row in rows: self.db_conn.execute("insert into ? values (?,?,?)", @@ -351,13 +360,13 @@ class LogManager(object): else: # send non-head data away gather(self.mpi_comm, - [(name, qdat.table.data) + [(name, self.get_table(name).data) for name, qdat in self.quantity_data.iteritems()], self.head_rank) # and erase it - for qdat in self.quantity_data.itervalues(): - qdat.table.clear() + for qname in self.quantity_data.iterkeys(): + self.get_table(qname).clear() def add_quantity(self, quantity, interval=1): """Add an object derived from L{LogQuantity} to this manager.""" @@ -411,7 +420,7 @@ class LogManager(object): # aggregate table data for dd in dep_data: - table = self.quantity_data[dd.name].table + table = self.get_table(dd.name) table.sort(["step"]) dd.table = table.aggregated(["step"], "value", dd.agg_func).data @@ -608,8 +617,8 @@ class LogManager(object): else: return 0 - data_block = dict((name, get_last_value(qdat.table)) - for name, qdat in self.quantity_data.iteritems()) + data_block = dict((name, get_last_value(self.get_table(qname))) + for qname in self.quantity_data.iterkeys()) if self.mpi_comm is not None: from boost.mpi import broadcast, gather