diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 2469334de47c47718d3c6d8ce1063de88362232c..c9b1282d898ae7b3466c19c0d7108d6c4a40bb1e 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -62,13 +62,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from numbers import Number from typing import Any, Callable, Iterable, List, Optional, Union, Tuple from functools import update_wrapper, partial, singledispatch import numpy as np -from arraycontext.context import ArrayContext, DeviceArray +from arraycontext.context import ArrayContext, DeviceArray, _ScalarLike from arraycontext.container import ( ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -738,14 +737,14 @@ def unflatten( # {{{ numpy conversion def from_numpy( - ary: Union[np.ndarray, np.generic, Number], + ary: Union[np.ndarray, _ScalarLike], actx: ArrayContext) -> ArrayOrContainerT: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ - def _from_numpy_with_check(subary: Union[np.ndarray, np.generic, Number]) \ + def _from_numpy_with_check(subary: Union[np.ndarray, _ScalarLike]) \ -> ArrayOrContainerT: if isinstance(subary, np.ndarray) or np.isscalar(subary): return actx.from_numpy(subary)