From 8d2278589c297e4b7090f19b80eea2ddef326ba3 Mon Sep 17 00:00:00 2001 From: Matthias Diener <mdiener@illinois.edu> Date: Sun, 13 Nov 2022 13:32:10 -0600 Subject: [PATCH] convergence: small mypy fixes (#157) * convergence: small fixes * fix types in add_data_point --- .github/workflows/ci.yml | 1 + pytools/convergence.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8a38a68..cfc276f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,6 +61,7 @@ jobs: python-version: '3.x' - name: "Main Script" run: | + EXTRA_INSTALL="numpy" curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-mypy.sh . ./prepare-and-run-mypy.sh python3 mypy diff --git a/pytools/convergence.py b/pytools/convergence.py index ad5fa6d..e302ac7 100644 --- a/pytools/convergence.py +++ b/pytools/convergence.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple import numpy as np +import numbers # {{{ eoc estimation -------------------------------------------------------------- @@ -33,16 +34,16 @@ class EOCRecorder: .. automethod:: write_gnuplot_file """ - def __init__(self): + def __init__(self) -> None: self.history: List[Tuple[float, float]] = [] def add_data_point(self, abscissa: float, error: float) -> None: - if not (np.isscalar(abscissa) + if not (isinstance(abscissa, numbers.Number) or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())): raise TypeError( f"'abscissa' is not a scalar: '{type(abscissa).__name__}'") - if not (np.isscalar(error) + if not (isinstance(error, numbers.Number) or (isinstance(error, np.ndarray) and error.shape == ())): raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'") -- GitLab