diff --git a/src/log.py b/src/log.py
index 4f50046741320c1d79b9f82b0f352cd4c151d634..f0639895fd4fe9630d3c9aa7699f7f5794d137ce 100644
--- a/src/log.py
+++ b/src/log.py
@@ -96,14 +96,16 @@ class LogManager:
             raise IOError, "cowardly refusing to overwrite '%s'" % self.filename
 
         from time import time
-        self.last_checkpoint = time()
+        self.start_time = time()
 
         self.mpi_comm = mpi_comm
 
         if mpi_comm is None:
             self.rank = 0
+            self.last_checkpoint = self.start_time
         else:
             self.rank = mpi_comm.rank
-            self.last_sync = self.last_checkpoint
+            self.next_sync = 10
+            self.head_rank = 0
 
         self.t_log = 0
@@ -124,15 +126,26 @@ class LogManager:
 
         # synchronize logs with parallel peers, if necessary
         if self.mpi_comm is not None:
-            if end_time - self.last_sync > 10:
-                self.synchronize_logs()
-                self.last_sync = end_time
+            # parallel case: sync, then checkpoint
+            if self.tick_count == self.next_sync:
 
-        # checkpoint log to disk, if necessary
-        if self.filename is not None:
-            if end_time - self.last_checkpoint > 10:
-                self.save()
-                self.last_checkpoint = end_time
+                if self.filename is not None:
+                    # implicitly synchronizes
+                    self.save()
+                else:
+                    self.synchronize_logs()
+
+                # figure out next sync tick
+                ticks_per_20_sec = 20*self.tick_count/max(1, end_time-self.start_time)
+                self.next_sync = self.tick_count + max(1, int(min(10, ticks_per_20_sec)))
+                from boost.mpi import broadcast
+                self.next_sync = broadcast(self.mpi_comm, self.next_sync, self.head_rank)
+        else:
+            # non-parallel case: checkpoint log to disk, if necessary
+            if self.filename is not None:
+                if end_time - self.last_checkpoint > 10:
+                    self.save()
+                    self.last_checkpoint = end_time
 
     def synchronize_logs(self):
         """Send logs to head node."""
@@ -140,16 +153,20 @@ class LogManager:
             return
 
         from boost.mpi import gather
-        root = 0
-        if self.mpi_comm.rank == root:
-            for rank_data in gather(self.mpi_comm, None, root)[1:]:
+        if self.mpi_comm.rank == self.head_rank:
+            for rank_data in gather(self.mpi_comm, None, self.head_rank)[1:]:
                 for name, rows in rank_data:
-                    self.quantity_data[name].insert_rows(rows)
+                    self.quantity_data[name].table.insert_rows(rows)
         else:
+            # send non-head data away
            gather(self.mpi_comm,
                    [(name, qdat.table.data)
                        for name, qdat in self.quantity_data.iteritems()],
-                   root)
+                   self.head_rank)
+
+            # and erase it
+            for qdat in self.quantity_data.itervalues():
+                qdat.table.clear()
 
     def add_quantity(self, quantity, interval=1):
        """Add an object derived from L{LogQuantity} to this manager."""
@@ -185,6 +202,12 @@ class LogManager:
             elif agg_name == "avg":
                 from pytools import average
                 agg_func = average
+            elif agg_name == "sum":
+                agg_func = sum
+            elif agg_name == "norm2":
+                from math import sqrt
+                agg_func = lambda iterable: sqrt(
+                        sum(entry**2 for entry in iterable))
             else:
                 raise ValueError, "invalid rank aggregator '%s'" % agg_name
         elif isinstance(dep, Subscript):
@@ -274,7 +297,7 @@ class LogManager:
     def save(self, filename=None):
         self.synchronize_logs()
 
-        if self.mpi_comm and not self.mpi_comm.rank != 0:
+        if self.mpi_comm and self.mpi_comm.rank != self.head_rank:
             return
 
         if filename is not None:
@@ -300,7 +323,7 @@ class LogManager:
                 open(filename, "w"), protocol=HIGHEST_PROTOCOL)
 
     def load(self, filename):
-        if self.mpi_comm and not self.mpi_comm.rank != 0:
+        if self.mpi_comm and self.mpi_comm.rank != self.head_rank:
             return
 
         from cPickle import load
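
The main idea in the second hunk is that the parallel path no longer compares wall-clock
times on every tick; it schedules the next synchronization as a tick number and broadcasts
that number from the head rank, so every rank reaches the collective gather on the same
tick. The following standalone sketch illustrates that scheduling idea outside of
LogManager; SyncSchedule, do_sync, and the constants are illustrative names and defaults,
not part of log.py, and the MPI broadcast is only indicated in a comment.

    from time import time

    class SyncSchedule(object):
        """Schedule syncs at most 10 ticks, and roughly at most 20 wall-clock
        seconds, apart, mirroring the rescheduling logic added above."""

        def __init__(self):
            self.start_time = time()
            self.tick_count = 0
            self.next_sync = 10  # first sync after a fixed number of ticks

        def tick(self, do_sync):
            self.tick_count += 1

            if self.tick_count == self.next_sync:
                do_sync()

                # Estimate how many ticks fit into 20 seconds from the observed
                # tick rate, then move the next sync that many ticks ahead --
                # capped at 10 ticks, and forced to be a positive integer so
                # the equality test above can actually fire.
                ticks_per_20_sec = 20*self.tick_count/max(1, time()-self.start_time)
                self.next_sync = self.tick_count + max(1, int(min(10, ticks_per_20_sec)))
                # In LogManager, next_sync is then broadcast from head_rank so
                # that all ranks agree on the tick at which to gather logs.

    schedule = SyncSchedule()
    for step in range(100):
        schedule.tick(lambda: None)  # pass the real sync/checkpoint action here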
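
The two new rank aggregators are plain reductions over a quantity's per-rank values when
logs from all ranks are combined; presumably they are selected by a suffix on the quantity
name (e.g. "myquantity.norm2", by analogy with the existing "min"/"max"/"avg"), though that
naming convention is an assumption here. A quick illustration of what they compute, with
made-up per-rank values:

    from math import sqrt

    def norm2(iterable):
        # matches the "norm2" aggregator added above: 2-norm across ranks
        return sqrt(sum(entry**2 for entry in iterable))

    per_rank_values = [3.0, 4.0, 12.0]    # hypothetical, one value per MPI rank
    print(sum(per_rank_values))           # "sum" aggregator   -> 19.0
    print(norm2(per_rank_values))         # "norm2" aggregator -> 13.0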