diff --git a/pytools/log.py b/pytools/log.py index ea4688baf57b7bef2367d39d7a1b8c6ef3cd40a8..6f484d7969da8a37075814d93f689d23a234cb9e 100644 --- a/pytools/log.py +++ b/pytools/log.py @@ -271,6 +271,7 @@ class LogManager(object): # watch stuff self.watches = [] self.next_watch_tick = 1 + self.have_nonlocal_watches = False # database binding if filename is not None: @@ -420,6 +421,10 @@ class LogManager(object): parsed = self._parse_expr(expr) parsed, dep_data = self._get_expr_dep_data(parsed) + from pytools import any + self.have_nonlocal_watches = self.have_nonlocal_watches or \ + any(dd.nonlocal_agg for dd in dep_data) + from pymbolic import compile compiled = compile(parsed, [dd.varname for dd in dep_data]) @@ -535,7 +540,8 @@ class LogManager(object): Aggregators are specified as follows: - C{qty.min}, C{qty.max}, C{qty.avg}, C{qty.sum}, C{qty.norm2} - - C{qty[rank_nbr] + - C{qty[rank_nbr]} + - C{qty.loc} """ parsed = self._parse_expr(expression) @@ -678,6 +684,13 @@ class LogManager(object): return parsed def _get_expr_dep_data(self, parsed): + class Nth: + def __init__(self, n): + self.n = n + + def __call__(self, lst): + return lst[self.n] + from pymbolic.mapper.dependency import DependencyMapper deps = DependencyMapper()(parsed) @@ -686,6 +699,8 @@ class LogManager(object): dep_data = [] from pymbolic.primitives import Variable, Lookup, Subscript for dep_idx, dep in enumerate(deps): + nonlocal_agg = True + if isinstance(dep, Variable): name = dep.name agg_func = self.quantity_data[name].default_aggregator @@ -698,7 +713,11 @@ class LogManager(object): assert isinstance(dep.aggregate, Variable) name = dep.aggregate.name agg_name = dep.name - if agg_name == "min": + + if agg_name == "loc": + agg_func = Nth(self.rank) + nonlocal_agg = False + elif agg_name == "min": agg_func = min elif agg_name == "max": agg_func = max @@ -717,13 +736,6 @@ class LogManager(object): assert isinstance(dep.aggregate, Variable) name = dep.aggregate.name - class Nth: - def __init__(self, n): - self.n = n - - def __call__(self, lst): - return lst[self.n] - from pymbolic import evaluate agg_func = Nth(evaluate(dep.index)) @@ -733,7 +745,8 @@ class LogManager(object): class DependencyData(Record): pass this_dep_data = DependencyData(name=name, qdat=qdat, agg_func=agg_func, - varname="logvar%d" % dep_idx, expr=dep) + varname="logvar%d" % dep_idx, expr=dep, + nonlocal_agg=nonlocal_agg) dep_data.append(this_dep_data) # substitute in the "logvar" variable names @@ -744,6 +757,9 @@ class LogManager(object): return parsed, dep_data def _watch_tick(self): + if not self.have_nonlocal_watches and self.rank != self.head_rank: + return + def get_last_value(table): if table: return table.data[-1][2] @@ -753,7 +769,7 @@ class LogManager(object): data_block = dict((qname, get_last_value(self.get_table(qname))) for qname in self.quantity_data.iterkeys()) - if self.mpi_comm is not None: + if self.mpi_comm is not None and self.have_nonlocal_watches: from boostmpi import broadcast, gather gathered_data = gather(self.mpi_comm, data_block, self.head_rank) @@ -783,7 +799,7 @@ class LogManager(object): ticks_per_sec = self.tick_count/max(1, time()-self.start_time) self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec)) - if self.mpi_comm is not None: + if self.mpi_comm is not None and self.have_nonlocal_watches: self.next_watch_tick = broadcast(self.mpi_comm, self.next_watch_tick, self.head_rank)