From 8d2278589c297e4b7090f19b80eea2ddef326ba3 Mon Sep 17 00:00:00 2001
From: Matthias Diener <mdiener@illinois.edu>
Date: Sun, 13 Nov 2022 13:32:10 -0600
Subject: [PATCH] convergence: small mypy fixes (#157)

* convergence: small fixes

* fix types in add_data_point
---
 .github/workflows/ci.yml | 1 +
 pytools/convergence.py   | 7 ++++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8a38a68..cfc276f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -61,6 +61,7 @@ jobs:
                 python-version: '3.x'
         -   name: "Main Script"
             run: |
+                EXTRA_INSTALL="numpy"
                 curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/prepare-and-run-mypy.sh
                 . ./prepare-and-run-mypy.sh python3 mypy
 
diff --git a/pytools/convergence.py b/pytools/convergence.py
index ad5fa6d..e302ac7 100644
--- a/pytools/convergence.py
+++ b/pytools/convergence.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple
 
 import numpy as np
+import numbers
 
 
 # {{{ eoc estimation --------------------------------------------------------------
@@ -33,16 +34,16 @@ class EOCRecorder:
     .. automethod:: write_gnuplot_file
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.history: List[Tuple[float, float]] = []
 
     def add_data_point(self, abscissa: float, error: float) -> None:
-        if not (np.isscalar(abscissa)
+        if not (isinstance(abscissa, numbers.Number)
                 or (isinstance(abscissa, np.ndarray) and abscissa.shape == ())):
             raise TypeError(
                     f"'abscissa' is not a scalar: '{type(abscissa).__name__}'")
 
-        if not (np.isscalar(error)
+        if not (isinstance(error, numbers.Number)
                 or (isinstance(error, np.ndarray) and error.shape == ())):
             raise TypeError(f"'error' is not a scalar: '{type(error).__name__}'")
 
-- 
GitLab