diff --git a/src/__init__.py b/src/__init__.py index d0dcd640f093a951faf1ee55583f8d3ba2f3239a..ab2b0bcb0a18975974400ca9f33c576e4931a57f 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -25,6 +25,15 @@ def factorial(n): +def norm_2(iterable): + return sum(i**2 for i in iterable)**0.5 + +def norm_p(iterable): + return sum(i**p for i in iterable)**(1/p) + + + + # Data structures ------------------------------------------------------------ class Record(object): def __init__(self, valuedict=None, exclude=["self"], **kwargs): diff --git a/src/log.py b/src/log.py index 36130db3be92f987d1a18b1a17279355d94a3211..f2f293712654cdd42084023a638f6a5475a71690 100644 --- a/src/log.py +++ b/src/log.py @@ -19,8 +19,9 @@ class LogQuantity: self.name = name self.unit = unit self.description = description - - default_aggregator = None + + @property + def default_aggregator(self): return None def __call__(self): """Return the current value of the diagnostic represented by this @@ -44,10 +45,12 @@ class CallableLogQuantityAdapter(LogQuantity): # manager functionality ------------------------------------------------------- class _QuantityData: - def __init__(self, quantity, interval=1, table=None): + def __init__(self, quantity, interval=1, table=None, default_aggregator=None): self.quantity = quantity self.interval = interval + self.default_aggregator = default_aggregator or quantity.default_aggregator + if table is None: from pytools.datatable import DataTable self.table = DataTable(["step", "rank", "value"]) @@ -337,13 +340,15 @@ class LogManager: save_buffers = dict( (name, _QuantityData( LogQuantity( - qbuf.quantity.name, - qbuf.quantity.unit, - qbuf.quantity.description, + qdat.quantity.name, + qdat.quantity.unit, + qdat.quantity.description, ), - qbuf.interval, - qbuf.table)) - for name, qbuf in self.quantity_data.iteritems()) + qdat.interval, + qdat.table, + qdat.default_aggregator, + )) + for name, qdat in self.quantity_data.iteritems()) from cPickle import dump, HIGHEST_PROTOCOL dump((save_buffers, self.constants, self.is_parallel), @@ -429,12 +434,12 @@ class LogManager: for dep_idx, dep in enumerate(deps): if isinstance(dep, Variable): name = dep.name - agg_func = self.quantity_data[name].quantity.default_aggregator + agg_func = self.quantity_data[name].default_aggregator if agg_func is None: if self.is_parallel: raise ValueError, "must specify explicit aggregator for '%s'" % name else: - agg_func = max # use something simple + agg_func = lambda lst: lst[0] elif isinstance(dep, Lookup): assert isinstance(dep.aggregate, Variable) name = dep.aggregate.name @@ -471,7 +476,8 @@ class LogManager: quantity = self.quantity_data[name].quantity from pytools import Record - this_dep_data = Record(name=name, quantity=quantity, agg_func=agg_func, + this_dep_data = Record(name=name, quantity=quantity, + agg_func=agg_func, varname="logvar%d" % dep_idx, expr=dep) dep_data.append(this_dep_data) @@ -515,7 +521,8 @@ class LogManager: self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec)) if self.mpi_comm is not None: - self.next_watch_tick = broadcast(self.mpi_comm, next_watch_tick, self.head_rank) + self.next_watch_tick = broadcast(self.mpi_comm, + self.next_watch_tick, self.head_rank) @@ -550,8 +557,6 @@ class LogUpdateDuration(LogQuantity): LogQuantity.__init__(self, name, "s", "Time spent updating the log") self.log_manager = mgr - default_aggregator = max - def __call__(self): return self.log_manager.t_log @@ -585,8 +590,6 @@ class TimestepCounter(LogQuantity): LogQuantity.__init__(self, name, "1", "Timesteps") self.steps = 0 - default_aggregator = max - def __call__(self): result = self.steps self.steps += 1 @@ -603,8 +606,6 @@ class TimestepDuration(LogQuantity): self.last_start = time() - default_aggregator = max - def __call__(self): now = time() result = now - self.last_start @@ -646,8 +647,6 @@ class SimulationTime(LogQuantity): self.dt = dt self.t = 0 - default_aggregator = max - def set_dt(self, dt): self.dt = dt