diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8a38a68a48dcb1612815cd765ae8c57c02d5658f..cfc276fbeca6bda76358407e2524a22ebb8a5e4f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,6 +61,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | + EXTRA_INSTALL="numpy" curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-mypy.sh . ./prepare-and-run-mypy.sh python3 mypy diff --git a/pytools/convergence.py b/pytools/convergence.py index ad5fa6d9c62b835fc1f3ddca96d3fd3c3a39659a..e302ac72e3c9f62611373257a2805af721ba5d53 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple import numpy as np +import numbers # {{{ eoc estimation -------------------------------------------------------------- @@ -33,16 +34,16 @@ class EOCRecorder: .. automethod:: write_gnuplot_file """ - def __init__(self): + def __init__(self) -> None: self.history: List[Tuple[float, float]] = [] def add_data_point(self, abscissa: float, error: float) -> None: - if not (np.isscalar(abscissa) + if not (isinstance(abscissa, numbers.Number) or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())): raise TypeError( f"'abscissa' is not a scalar: '{type(abscissa).__name__}'") - if not (np.isscalar(error) + if not (isinstance(error, numbers.Number) or (isinstance(error, np.ndarray) and error.shape == ())): raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'")