diff --git a/pytools/convergence.py b/pytools/convergence.py index 5b5128d493ccc9a1b2dd8fb5ec2bc1b7d2d6b67c..714a5d3102763d5fa9838d690bf5c278012b5d06 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -37,13 +37,14 @@ class EOCRecorder: self.history: List[Tuple[float, float]] = [] def add_data_point(self, abscissa: float, error: float) -> None: - from numbers import Number - if not isinstance(abscissa, Number): + if not (np.isscalar(abscissa) + or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())): raise TypeError( - f"'abscissa' is not a number: '{type(abscissa).__name__}'") + f"'abscissa' is not a scalar: '{type(abscissa).__name__}'") - if not isinstance(error, Number): - raise TypeError(f"'error' is not a number: '{type(error).__name__}'") + if not (np.isscalar(error) + or (isinstance(error, np.ndarray) and error.shape == ())): + raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'") self.history.append((abscissa, error))