From 6da40e30250b77e9802e1515d0a7d3ec686da1ae Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 28 Oct 2021 21:38:56 -0500 Subject: [PATCH] improve EOCRecorder docs and checks --- pytools/convergence.py | 49 +++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/pytools/convergence.py b/pytools/convergence.py index 46d8b03..5b5128d 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -1,3 +1,5 @@ +from typing import List, Optional, Tuple + import numpy as np @@ -20,13 +22,34 @@ def estimate_order_of_convergence(abscissae, errors): class EOCRecorder: + """ + .. automethod:: add_data_point + + .. automethod:: estimate_order_of_convergence + .. automethod:: order_estimate + .. automethod:: max_error + + .. automethod:: pretty_print + .. automethod:: write_gnuplot_file + """ + def __init__(self): - self.history = [] + self.history: List[Tuple[float, float]] = [] + + def add_data_point(self, abscissa: float, error: float) -> None: + from numbers import Number + if not isinstance(abscissa, Number): + raise TypeError( + f"'abscissa' is not a number: '{type(abscissa).__name__}'") + + if not isinstance(error, Number): + raise TypeError(f"'error' is not a number: '{type(error).__name__}'") - def add_data_point(self, abscissa, error): self.history.append((abscissa, error)) - def estimate_order_of_convergence(self, gliding_mean=None): + def estimate_order_of_convergence(self, + gliding_mean: Optional[int] = None, + ) -> np.ndarray: abscissae = np.array([a for a, e in self.history]) errors = np.array([e for a, e in self.history]) @@ -46,10 +69,10 @@ class EOCRecorder: abscissae[i:i+gliding_mean], errors[i:i+gliding_mean]) return result - def order_estimate(self): + def order_estimate(self) -> float: return self.estimate_order_of_convergence()[0, 1] - def max_error(self): + def max_error(self) -> float: return max(err for absc, err in self.history) def _to_table(self, *, @@ -78,13 +101,13 @@ class EOCRecorder: return tbl def pretty_print(self, *, - abscissa_label="h", - error_label="Error", - gliding_mean=2, - abscissa_format="%s", - error_format="%s", - eoc_format="%s", - table_type="markdown"): + abscissa_label: str = "h", + error_label: str = "Error", + gliding_mean: int = 2, + abscissa_format: str = "%s", + error_format: str = "%s", + eoc_format: str = "%s", + table_type: str = "markdown") -> str: tbl = self._to_table( abscissa_label=abscissa_label, error_label=error_label, abscissa_format=abscissa_format, @@ -112,7 +135,7 @@ class EOCRecorder: def __str__(self): return self.pretty_print() - def write_gnuplot_file(self, filename): + def write_gnuplot_file(self, filename: str) -> None: outfile = open(filename, "w") for absc, err in self.history: outfile.write(f"{absc:f} {err:f}\n") -- GitLab