From b2e28015e82902103a29095d53a1df099a2affc7 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 11 Jan 2022 11:09:27 -0600
Subject: [PATCH] Tighten type information for from_numpy

---
 arraycontext/container/traversal.py    | 10 +++++++---
 arraycontext/context.py                |  3 ++-
 arraycontext/impl/pyopencl/__init__.py |  4 ++--
 arraycontext/impl/pytato/__init__.py   |  6 +++---
 4 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index 07c1544..b29dc86 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -62,6 +62,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
+from numbers import Number
 from typing import Any, Callable, Iterable, List, Optional, Union, Tuple
 from functools import update_wrapper, partial, singledispatch
 
@@ -732,13 +733,16 @@ def unflatten(
 
 # {{{ numpy conversion
 
-def from_numpy(ary: Any, actx: ArrayContext) -> Any:
+def from_numpy(
+        ary: Union[np.ndarray, np.generic, Number],
+        actx: ArrayContext) -> ArrayOrContainerT:
     """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer`
     to the base array type of :class:`~arraycontext.ArrayContext`.
 
     The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
     """
-    def _from_numpy_with_check(subary: Any) -> Any:
+    def _from_numpy_with_check(subary: Union[np.ndarray, np.generic, Number]) \
+            -> ArrayOrContainerT:
         if isinstance(subary, np.ndarray) or np.isscalar(subary):
             return actx.from_numpy(subary)
         else:
@@ -747,7 +751,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any:
     return rec_map_array_container(_from_numpy_with_check, ary)
 
 
-def to_numpy(ary: Any, actx: ArrayContext) -> Any:
+def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
     """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to
     :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*.
 
diff --git a/arraycontext/context.py b/arraycontext/context.py
index fa70513..aa3054d 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -125,6 +125,7 @@ from pytools.tag import Tag
 
 DeviceArray = Any
 DeviceScalar = Any
+_ScalarLike = Union[int, float, complex, np.generic]
 
 
 # {{{ ArrayContext
@@ -197,7 +198,7 @@ class ArrayContext(ABC):
         return self.zeros(shape=ary.shape, dtype=ary.dtype)
 
     @abstractmethod
-    def from_numpy(self, array: np.ndarray):
+    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
         r"""
         :returns: the :class:`numpy.ndarray` *array* converted to the
             array context's array type. The returned array will be
diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py
index 585a99e..7336aa9 100644
--- a/arraycontext/impl/pyopencl/__init__.py
+++ b/arraycontext/impl/pyopencl/__init__.py
@@ -34,7 +34,7 @@ import numpy as np
 
 from pytools.tag import Tag
 
-from arraycontext.context import ArrayContext
+from arraycontext.context import ArrayContext, _ScalarLike
 
 
 if TYPE_CHECKING:
@@ -156,7 +156,7 @@ class PyOpenCLArrayContext(ArrayContext):
         return cl_array.zeros(self.queue, shape=shape, dtype=dtype,
                 allocator=self.allocator)
 
-    def from_numpy(self, array: np.ndarray):
+    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
         import pyopencl.array as cl_array
         return cl_array.to_device(self.queue, array, allocator=self.allocator)
 
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 0e12b92..744fcbc 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -41,7 +41,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
 """
 
-from arraycontext.context import ArrayContext
+from arraycontext.context import ArrayContext, _ScalarLike
 import numpy as np
 from typing import Any, Callable, Union, Sequence, TYPE_CHECKING
 from pytools.tag import Tag
@@ -98,10 +98,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
         import pytato as pt
         return pt.zeros(shape, dtype)
 
-    def from_numpy(self, np_array: np.ndarray):
+    def from_numpy(self, array: Union[np.ndarray, _ScalarLike]):
         import pytato as pt
         import pyopencl.array as cla
-        cl_array = cla.to_device(self.queue, np_array)
+        cl_array = cla.to_device(self.queue, array)
         return pt.make_data_wrapper(cl_array)
 
     def to_numpy(self, array):
-- 
GitLab