diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 28910150feb4d0124cd633a00731cb2ff39abf37..7d724b843eb3b2b9799ad61d97bd66c7328b239e 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,15 +30,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Dict, Sequence, Union +from typing import Any, Dict import numpy as np import loopy as lp -from pytools.tag import Tag +from pytools.tag import ToTagSetConvertible from arraycontext.container.traversal import rec_map_array_container, with_array_context -from arraycontext.context import ArrayContext +from arraycontext.context import ( + ArrayContext, + ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, + NumpyOrContainerOrScalar, +) class NumpyNonObjectArrayMetaclass(type): @@ -56,7 +61,7 @@ class NumpyArrayContext(ArrayContext): .. automethod:: __init__ """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._loopy_transform_cache: \ Dict[lp.TranslationUnit, lp.TranslationUnit] = {} @@ -72,18 +77,14 @@ class NumpyArrayContext(ArrayContext): def clone(self): return type(self)() - def empty(self, shape, dtype): - return np.empty(shape, dtype=dtype) - - def zeros(self, shape, dtype): - return np.zeros(shape, dtype) - - def from_numpy(self, np_array: np.ndarray): - # Uh oh... - return np_array + def from_numpy(self, + array: NumpyOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: + return array - def to_numpy(self, array): - # Uh oh... + def to_numpy(self, + array: ArrayOrContainerOrScalar + ) -> NumpyOrContainerOrScalar: return array def call_loopy(self, t_unit, **kwargs): @@ -119,11 +120,16 @@ class NumpyArrayContext(ArrayContext): "transform_loopy_program. Sub-classes are supposed " "to implement it.") - def tag(self, tags: Union[Sequence[Tag], Tag], array): + def tag(self, + tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: # Numpy doesn't support tagging return array - def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + def tag_axis(self, + iaxis: int, tags: ToTagSetConvertible, + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + # Numpy doesn't support tagging return array def einsum(self, spec, *args, arg_names=None, tagged=()):