From 2721385374a3a98dfe64386dbed3ac39e1a825ba Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Mon, 16 Aug 2021 15:39:13 -0500 Subject: [PATCH] allow choosing table type in EOCRecorder --- pytools/__init__.py | 7 ++++--- pytools/convergence.py | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pytools/__init__.py b/pytools/__init__.py index 386a4c1..6d2bdfc 100644 --- a/pytools/__init__.py +++ b/pytools/__init__.py @@ -1569,8 +1569,9 @@ class Table: .. automethod:: add_row .. automethod:: __str__ - .. automethod:: latex .. automethod:: github_markdown + .. automethod:: csv + .. automethod:: latex """ def __init__(self, alignments=None): @@ -1697,8 +1698,8 @@ class Table: if hline_after is None: hline_after = [] lines = [] - for row_nr, row in list(enumerate(self.rows))[skip_lines:]: - lines.append(" & ".join(row)+r" \\") + for row_nr, row in enumerate(self.rows[skip_lines:]): + lines.append(f"{' & '.join(row)} \\") if row_nr in hline_after: lines.append(r"\hline") diff --git a/pytools/convergence.py b/pytools/convergence.py index ae2a5c9..a9e2dd4 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -52,7 +52,7 @@ class EOCRecorder: def max_error(self): return max(err for absc, err in self.history) - def pretty_print(self, + def _to_table(self, *, abscissa_label="h", error_label="Error", gliding_mean=2, @@ -75,11 +75,39 @@ class EOCRecorder: tbl.add_row((absc_str, err_str, eoc_str)) + return tbl + + def pretty_print(self, *, + abscissa_label="h", + error_label="Error", + gliding_mean=2, + abscissa_format="%s", + error_format="%s", + eoc_format="%s", + table_type="markdown"): + tbl = self._to_table( + abscissa_label=abscissa_label, error_label=error_label, + abscissa_format=abscissa_format, + error_format=error_format, + eoc_format=eoc_format, + gliding_mean=gliding_mean) + + if table_type == "markdown": + tbl_str = tbl.github_markdown() + elif table_type == "latex": + tbl_str = tbl.latex() + elif table_type == "ascii": + tbl_str = str(tbl) + elif table_type == "csv": + tbl_str = tbl.csv() + else: + raise ValueError(f"unknown table type: {table_type}") + if len(self.history) > 1: - return "{}\n\nOverall EOC: {}".format(str(tbl), + return "{}\n\nOverall EOC: {}".format(tbl_str, self.estimate_order_of_convergence()[0, 1]) else: - return str(tbl) + return tbl_str def __str__(self): return self.pretty_print() -- GitLab