From b2e28015e82902103a29095d53a1df099a2affc7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 11 Jan 2022 11:09:27 -0600 Subject: [PATCH] Tighten type information for from_numpy --- arraycontext/container/traversal.py | 10 +++++++--- arraycontext/context.py | 3 ++- arraycontext/impl/pyopencl/__init__.py | 4 ++-- arraycontext/impl/pytato/__init__.py | 6 +++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 07c1544..b29dc86 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -62,6 +62,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from numbers import Number from typing import Any, Callable, Iterable, List, Optional, Union, Tuple from functools import update_wrapper, partial, singledispatch @@ -732,13 +733,16 @@ def unflatten( # {{{ numpy conversion -def from_numpy(ary: Any, actx: ArrayContext) -> Any: +def from_numpy( + ary: Union[np.ndarray, np.generic, Number], + actx: ArrayContext) -> ArrayOrContainerT: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ - def _from_numpy_with_check(subary: Any) -> Any: + def _from_numpy_with_check(subary: Union[np.ndarray, np.generic, Number]) \ + -> ArrayOrContainerT: if isinstance(subary, np.ndarray) or np.isscalar(subary): return actx.from_numpy(subary) else: @@ -747,7 +751,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any: return rec_map_array_container(_from_numpy_with_check, ary) -def to_numpy(ary: Any, actx: ArrayContext) -> Any: +def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. diff --git a/arraycontext/context.py b/arraycontext/context.py index fa70513..aa3054d 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -125,6 +125,7 @@ from pytools.tag import Tag DeviceArray = Any DeviceScalar = Any +_ScalarLike = Union[int, float, complex, np.generic] # {{{ ArrayContext @@ -197,7 +198,7 @@ class ArrayContext(ABC): return self.zeros(shape=ary.shape, dtype=ary.dtype) @abstractmethod - def from_numpy(self, array: np.ndarray): + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): r""" :returns: the :class:`numpy.ndarray` *array* converted to the array context's array type. The returned array will be diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 585a99e..7336aa9 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -34,7 +34,7 @@ import numpy as np from pytools.tag import Tag -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, _ScalarLike if TYPE_CHECKING: @@ -156,7 +156,7 @@ class PyOpenCLArrayContext(ArrayContext): return cl_array.zeros(self.queue, shape=shape, dtype=dtype, allocator=self.allocator) - def from_numpy(self, array: np.ndarray): + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): import pyopencl.array as cl_array return cl_array.to_device(self.queue, array, allocator=self.allocator) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 0e12b92..744fcbc 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -41,7 +41,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, _ScalarLike import numpy as np from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag @@ -98,10 +98,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext): import pytato as pt return pt.zeros(shape, dtype) - def from_numpy(self, np_array: np.ndarray): + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): import pytato as pt import pyopencl.array as cla - cl_array = cla.to_device(self.queue, np_array) + cl_array = cla.to_device(self.queue, array) return pt.make_data_wrapper(cl_array) def to_numpy(self, array): -- GitLab