diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 07c154464c1e06d475fe88a89078f83bf5328372..b29dc86df66461ab9ca214e347954f3742baae31 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -62,6 +62,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from numbers import Number from typing import Any, Callable, Iterable, List, Optional, Union, Tuple from functools import update_wrapper, partial, singledispatch @@ -732,13 +733,16 @@ def unflatten( # {{{ numpy conversion -def from_numpy(ary: Any, actx: ArrayContext) -> Any: +def from_numpy( + ary: Union[np.ndarray, np.generic, Number], + actx: ArrayContext) -> ArrayOrContainerT: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ - def _from_numpy_with_check(subary: Any) -> Any: + def _from_numpy_with_check(subary: Union[np.ndarray, np.generic, Number]) \ + -> ArrayOrContainerT: if isinstance(subary, np.ndarray) or np.isscalar(subary): return actx.from_numpy(subary) else: @@ -747,7 +751,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any: return rec_map_array_container(_from_numpy_with_check, ary) -def to_numpy(ary: Any, actx: ArrayContext) -> Any: +def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. diff --git a/arraycontext/context.py b/arraycontext/context.py index fa705136a8ceb7d224a75b1ddee33e6cc5408682..aa3054d94c320d0252d1e2705c00e68558698f1b 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -125,6 +125,7 @@ from pytools.tag import Tag DeviceArray = Any DeviceScalar = Any +_ScalarLike = Union[int, float, complex, np.generic] # {{{ ArrayContext @@ -197,7 +198,7 @@ class ArrayContext(ABC): return self.zeros(shape=ary.shape, dtype=ary.dtype) @abstractmethod - def from_numpy(self, array: np.ndarray): + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): r""" :returns: the :class:`numpy.ndarray` *array* converted to the array context's array type. The returned array will be diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 585a99eac5279fdeb06f7ef76166c58744837c1d..7336aa9b83cfc3eb524c8da682c6f898cd06abe2 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -34,7 +34,7 @@ import numpy as np from pytools.tag import Tag -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, _ScalarLike if TYPE_CHECKING: @@ -156,7 +156,7 @@ class PyOpenCLArrayContext(ArrayContext): return cl_array.zeros(self.queue, shape=shape, dtype=dtype, allocator=self.allocator) - def from_numpy(self, array: np.ndarray): + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): import pyopencl.array as cl_array return cl_array.to_device(self.queue, array, allocator=self.allocator) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 0e12b92f7b3fddcea5c869afa5277c41774f6281..744fcbc4ec791133d19d641dbb1d43c3447894f2 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -41,7 +41,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, _ScalarLike import numpy as np from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag @@ -98,10 +98,10 @@ class PytatoPyOpenCLArrayContext(ArrayContext): import pytato as pt return pt.zeros(shape, dtype) - def from_numpy(self, np_array: np.ndarray): + def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): import pytato as pt import pyopencl.array as cla - cl_array = cla.to_device(self.queue, np_array) + cl_array = cla.to_device(self.queue, array) return pt.make_data_wrapper(cl_array) def to_numpy(self, array):