From 688d007cc105794ca005562559abc1c924c84c1c Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 30 Oct 2021 23:57:54 -0500 Subject: [PATCH] fix check in EOCRecorder --- pytools/convergence.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytools/convergence.py b/pytools/convergence.py index 5b5128d..714a5d3 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -37,13 +37,14 @@ class EOCRecorder: self.history: List[Tuple[float, float]] = [] def add_data_point(self, abscissa: float, error: float) -> None: - from numbers import Number - if not isinstance(abscissa, Number): + if not (np.isscalar(abscissa) + or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())): raise TypeError( - f"'abscissa' is not a number: '{type(abscissa).__name__}'") + f"'abscissa' is not a scalar: '{type(abscissa).__name__}'") - if not isinstance(error, Number): - raise TypeError(f"'error' is not a number: '{type(error).__name__}'") + if not (np.isscalar(error) + or (isinstance(error, np.ndarray) and error.shape == ())): + raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'") self.history.append((abscissa, error)) -- GitLab