From d21c7655133a7262e8d1137ca25b38b16bbb046a Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Wed, 12 Mar 2008 23:43:32 -0400
Subject: [PATCH] Enable automatic logging of warnings

---
 bin/logtool |  5 +++++
 src/log.py  | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/bin/logtool b/bin/logtool
index 5c8cd37..f5d303a 100755
--- a/bin/logtool
+++ b/bin/logtool
@@ -13,6 +13,7 @@ following:
 "datafile outfile expr_x,expr_y" to write out a data file.
 "table variable" to print the full data table for a time series variable.
 "prefix string" to set the legend prefix for all following plot commands.
+"warnings" to list the warnings that were issued during the logged run.
 """
     parser = OptionParser(usage="%prog FILE COMMANDS FILE COMMANDS...",
             description=description)
@@ -134,6 +135,10 @@ following:
                     hold=True,  **kwargs)
 
             did_plot = True
+        elif cmd == "warnings":
+            check_no_file()
+            print logmgr.get_warnings()
+
         elif cmd == "datafile":
             check_no_file()
 
diff --git a/src/log.py b/src/log.py
index 3744971..f308947 100644
--- a/src/log.py
+++ b/src/log.py
@@ -159,7 +159,7 @@ class LogManager(object):
     data in a saved log.
     """
 
-    def __init__(self, filename, mode, mpi_comm=None):
+    def __init__(self, filename, mode, mpi_comm=None, capture_warnings=True):
         """Initialize this log manager instance.
 
         @arg filename: If given, the filename to which this log is bound.
@@ -168,6 +168,8 @@ class LogManager(object):
           database is initially empty.
         @arg mpi_comm: A C{boost.mpi} communicator. If given, logs are periodically
           synchronized to the head node, which then writes them out to disk.
+        @arg capture_warnings: Tap the Python warnings facility and save warnings
+          to the log file.
         """
 
         assert isinstance(mode, basestring), "mode must be a string"
@@ -207,6 +209,7 @@ class LogManager(object):
             import sqlite3
 
             self.db_conn = sqlite3.connect(filename, timeout=30)
+            self.mode = mode
             try:
                 self.db_conn.execute("select * from quantities;")
                 if mode == "w":
@@ -227,10 +230,49 @@ class LogManager(object):
                   create table constants (
                     name text, 
                     value blob)""")
+                self.db_conn.execute("""
+                  create table warnings (
+                    step integer,
+                    message text, 
+                    category text,
+                    filename text,
+                    lineno integer
+                    )""")
                 self.set_constant("is_parallel", self.is_parallel)
+                self.schema_version = 1
+                self.set_constant("schema_version", self.schema_version)
         else:
             self.db_conn = None
 
+        self.old_showwarning = None
+        if capture_warnings:
+            self.capture_warnings(True)
+
+    def capture_warnings(self, enable=True):
+        import warnings
+        if enable:
+            if self.old_showwarning is None:
+                pass
+                self.old_showwarning = warnings.showwarning
+                warnings.showwarning = self._showwarning
+            else:
+                raise RuntimeError, "Warnings capture was enabled twice"
+        else:
+            if self.old_showwarning is None:
+                raise RuntimeError, "Warnings capture was disabled, but never enabled"
+            else:
+                warnings.showwarning = self.old_showwarning
+                self.old_showwarning = None
+
+    def _showwarning(self, message, category, filename, lineno):
+        self.old_showwarning(message, category, filename, lineno)
+
+        if (self.schema_version >= 1 
+                and self.mode == "w" 
+                and self.db_conn is not None):
+            self.db_conn.execute("insert into warnings values (?,?,?,?,?)",
+                    (self.tick_count, message.message, str(category), filename, lineno))
+
     def _load(self):
         if self.mpi_comm and self.mpi_comm.rank != self.head_rank:
             return
@@ -239,6 +281,8 @@ class LogManager(object):
         for name, value in self.db_conn.execute("select name, value from constants"):
             self.constants[name] = loads(value)
 
+        self.schema_version = self.constants.get("schema_version", 0)
+
         self.is_parallel = self.constants["is_parallel"]
 
         for name, unit, description, def_agg in self.db_conn.execute(
@@ -246,7 +290,6 @@ class LogManager(object):
             qdat = self.quantity_data[name] = _QuantityData(
                     unit, description, loads(def_agg))
 
-
     def get_table(self, q_name):
         if q_name not in self.quantity_data:
             raise KeyError, "invalid quantity name '%s'" % q_name
@@ -264,6 +307,17 @@ class LogManager(object):
 
             return result
 
+    def get_warnings(self):
+        from pytools.datatable import DataTable
+        result = DataTable(["step", "message", "category", "filename", "lineno"])
+
+        if self.schema_version >= 1 and self.db_conn is not None:
+            for row in self.db_conn.execute(
+                    "select step, message, category, filename, lineno from warnings"):
+                result.insert_row(row)
+
+        return result
+
     def add_watches(self, watches):
         """Add quantities that are printed after every time step."""
 
-- 
GitLab