From 0569eab090b8caf30a9e1a00acd6636b72f79620 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Wed, 29 Jun 2022 15:18:58 -0500 Subject: [PATCH] Clean up array and container type annotations Co-authored-by: Michael Campbell <mtcampbe@illinois.edu> --- arraycontext/__init__.py | 22 +++- arraycontext/container/__init__.py | 83 ++++++++----- arraycontext/container/traversal.py | 163 ++++++++++++++----------- arraycontext/context.py | 101 +++++++++------ arraycontext/impl/jax/__init__.py | 4 +- arraycontext/impl/pyopencl/__init__.py | 4 +- arraycontext/impl/pytato/__init__.py | 10 +- arraycontext/impl/pytato/compile.py | 4 +- test/test_arraycontext.py | 9 +- 9 files changed, 245 insertions(+), 155 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index b8287cf..06e0b96 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -30,8 +30,14 @@ THE SOFTWARE. import sys from .context import ( + ArrayContext, + + Scalar, ScalarLike, Array, ArrayT, - ArrayContext, Scalar, tag_axes) + ArrayOrContainer, ArrayOrContainerT, + ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, + + tag_axes) from .transform_metadata import (CommonSubexpressionTag, ElementwiseMapKernelTag) @@ -40,8 +46,8 @@ from .transform_metadata import (CommonSubexpressionTag, from .metadata import _FirstAxisIsElementsTag from .container import ( - ArrayOrContainerT as ArrayOrContainer, ArrayOrContainerT, - ArrayContainer, NotAnArrayContainerError, + ArrayContainer, ArrayContainerT, + NotAnArrayContainerError, is_array_container, is_array_container_type, get_container_context_opt, get_container_context_recursively, get_container_context_recursively_opt, @@ -81,14 +87,18 @@ from .loopy import make_loopy_program __all__ = ( + "ArrayContext", "Scalar", "Array", + "Scalar", "ScalarLike", "Array", "ArrayT", - "ArrayContext", "Scalar", "tag_axes", + "ArrayOrContainer", "ArrayOrContainerT", + "ArrayOrContainerOrScalar", "ArrayOrContainerOrScalarT", + "tag_axes", "CommonSubexpressionTag", "ElementwiseMapKernelTag", - "ArrayOrContainer", "ArrayOrContainerT", - "ArrayContainer", "NotAnArrayContainerError", + "ArrayContainer", "ArrayContainerT", + "NotAnArrayContainerError", "is_array_container", "is_array_container_type", "get_container_context_opt", "get_container_context_recursively_opt", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 789dd29..71bccee 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -3,22 +3,10 @@ """ .. currentmodule:: arraycontext -.. class:: ArrayT - :canonical: arraycontext.container.ArrayT - - :class:`~typing.TypeVar` for arrays. - -.. class:: ContainerT - :canonical: arraycontext.container.ContainerT - - :class:`~typing.TypeVar` for array container-like objects. - -.. class:: ArrayOrContainerT - :canonical: arraycontext.container.ArrayOrContainerT - - :class:`~typing.TypeVar` for arrays or array container-like objects. - .. autoclass:: ArrayContainer +.. class:: ArrayContainerT + + A type variable with a lower bound of :class:`ArrayContainer`. .. autoexception:: NotAnArrayContainerError @@ -38,8 +26,23 @@ Context retrieval --------------------------------------------------------- .. autofunction:: register_multivector_as_array_container + +.. currentmodule:: arraycontext.container + +Canonical locations for type annotations +---------------------------------------- + +.. class:: ArrayContainerT + + :canonical: arraycontext.ArrayContainerT + +.. class:: ArrayOrContainerT + + :canonical: arraycontext.ArrayOrContainerT """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -67,22 +70,24 @@ THE SOFTWARE. from functools import singledispatch from arraycontext.context import ArrayContext -from typing import Any, Iterable, Tuple, TypeVar, Optional, Union, TYPE_CHECKING +from typing import Any, Iterable, Tuple, Optional, TypeVar, Protocol, TYPE_CHECKING import numpy as np -ArrayT = TypeVar("ArrayT") -ContainerT = TypeVar("ContainerT") -ArrayOrContainerT = Union[ArrayT, ContainerT] +# For use in singledispatch type annotations, because sphinx can't figure out +# what 'np' is. +import numpy + if TYPE_CHECKING: from pymbolic.geometric_algebra import MultiVector + from arraycontext import ArrayOrContainer # {{{ ArrayContainer -class ArrayContainer: - r""" - A generic container for the array type supported by the +class ArrayContainer(Protocol): + """ + A protocol for generic containers of the array type supported by the :class:`ArrayContext`. The functionality required for the container to operated is supplied via @@ -113,17 +118,31 @@ class ArrayContainer: .. note:: - This class is used in type annotation. Inheriting from it confers no - special meaning or behavior. + This class is used in type annotation and as a marker of array container + attributes for :func:`~arraycontext.dataclass_array_container`. + As a protocol, it is not intended as a superclass. """ + # Array containers do not need to have any particular features, so this + # protocol is deliberately empty. + + # This *is* used as a type annotation in dataclasses that are processed + # by dataclass_array_container, where it's used to recognize attributes + # that are container-typed. + + pass + + +ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) + class NotAnArrayContainerError(TypeError): """:class:`TypeError` subclass raised when an array container is expected.""" @singledispatch -def serialize_container(ary: Any) -> Iterable[Tuple[Any, Any]]: +def serialize_container( + ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]: r"""Serialize the array container into an iterable over its components. The order of the components and their identifiers are entirely under @@ -149,7 +168,9 @@ def serialize_container(ary: Any) -> Iterable[Tuple[Any, Any]]: @singledispatch -def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) -> Any: +def deserialize_container( + template: ArrayContainerT, + iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT: """Deserialize an iterable into an array container. :param template: an instance of an existing object that @@ -214,7 +235,8 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: # {{{ object arrays as array containers @serialize_container.register(np.ndarray) -def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: +def _serialize_ndarray_container( + ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") @@ -232,9 +254,10 @@ def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]: @deserialize_container.register(np.ndarray) -def _deserialize_ndarray_container( - template: np.ndarray, - iterable: Iterable[Tuple[Any, Any]]) -> np.ndarray: +# https://github.com/python/mypy/issues/13040 +def _deserialize_ndarray_container( # type: ignore[misc] + template: numpy.ndarray, + iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 2911340..6c858fe 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -39,6 +39,8 @@ Algebraic operations .. autofunction:: outer """ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -63,15 +65,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Iterable, List, Optional, Union, Tuple +from typing import Any, Callable, Iterable, List, Optional, Union, Tuple, cast from functools import update_wrapper, partial, singledispatch from warnings import warn import numpy as np -from arraycontext.context import ArrayContext, Array, _ScalarLike +from arraycontext.context import ( + ArrayT, ArrayOrContainer, ArrayOrContainerT, + ArrayOrContainerOrScalar, ScalarLike, + ArrayContext, Array +) from arraycontext.container import ( - ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, + NotAnArrayContainerError, + ArrayContainer, serialize_container, deserialize_container, get_container_context_recursively_opt) @@ -79,10 +86,10 @@ from arraycontext.container import ( # {{{ array container traversal helpers def _map_array_container_impl( - f: Callable[[Any], Any], - ary: ArrayOrContainerT, *, + f: Callable[[ArrayOrContainer], ArrayOrContainer], + ary: ArrayOrContainer, *, leaf_cls: Optional[type] = None, - recursive: bool = False) -> ArrayOrContainerT: + recursive: bool = False) -> ArrayOrContainer: """Helper for :func:`rec_map_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -90,7 +97,7 @@ def _map_array_container_impl( specific container classes. By default, the recursion is stopped when a non-:class:`ArrayContainer` class is encountered. """ - def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: + def rec(_ary: ArrayOrContainer) -> ArrayOrContainer: if type(_ary) is leaf_cls: # type(ary) is never None return f(_ary) @@ -110,9 +117,10 @@ def _map_array_container_impl( def _multimap_array_container_impl( f: Callable[..., Any], *args: Any, - reduce_func: Callable[[ContainerT, Iterable[Tuple[Any, Any]]], Any] = None, + reduce_func: Optional[Callable[ + [ArrayContainer, Iterable[Tuple[Any, Any]]], Any]] = None, leaf_cls: Optional[type] = None, - recursive: bool = False) -> ArrayOrContainerT: + recursive: bool = False) -> ArrayOrContainer: """Helper for :func:`rec_multimap_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -198,7 +206,7 @@ def _multimap_array_container_impl( return f(*new_args) update_wrapper(wrapper, f) - template_ary: ContainerT = args[container_indices[0]] + template_ary: ArrayContainer = args[container_indices[0]] return _map_array_container_impl( wrapper, template_ary, leaf_cls=leaf_cls, recursive=recursive) @@ -221,7 +229,7 @@ def _multimap_array_container_impl( def map_array_container( f: Callable[[Any], Any], - ary: ArrayOrContainerT) -> ArrayOrContainerT: + ary: ArrayOrContainer) -> ArrayOrContainer: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but @@ -259,8 +267,8 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: def rec_map_array_container( f: Callable[[Any], Any], - ary: ArrayOrContainerT, - leaf_class: Optional[type] = None) -> ArrayOrContainerT: + ary: ArrayOrContainer, + leaf_class: Optional[type] = None) -> ArrayOrContainer: r"""Applies *f* recursively to an :class:`ArrayContainer`. For a non-recursive version see :func:`map_array_container`. @@ -272,15 +280,15 @@ def rec_map_array_container( def mapped_over_array_containers( - f: Optional[Callable[[Any], Any]] = None, + f: Optional[Callable[[ArrayOrContainer], ArrayOrContainer]] = None, leaf_class: Optional[type] = None) -> Union[ - Callable[[ArrayOrContainerT], ArrayOrContainerT], + Callable[[ArrayOrContainer], ArrayOrContainer], Callable[ [Callable[[Any], Any]], - Callable[[ArrayOrContainerT], ArrayOrContainerT]]]: + Callable[[ArrayOrContainer], ArrayOrContainer]]]: """Decorator around :func:`rec_map_array_container`.""" - def decorator(g: Callable[[Any], Any]) -> Callable[ - [ArrayOrContainerT], ArrayOrContainerT]: + def decorator(g: Callable[[ArrayOrContainer], ArrayOrContainer]) -> Callable[ + [ArrayOrContainer], ArrayOrContainer]: wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class) update_wrapper(wrapper, g) return wrapper @@ -329,11 +337,14 @@ def multimapped_over_array_containers( # {{{ keyed array container traversal +KeyType = Any + + def keyed_map_array_container( f: Callable[ - [Any, ArrayOrContainerT], - ArrayOrContainerT], - ary: ArrayOrContainerT) -> ArrayOrContainerT: + [KeyType, ArrayOrContainer], + ArrayOrContainer], + ary: ArrayOrContainer) -> ArrayOrContainer: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`map_array_container`, but *f* also takes an @@ -356,8 +367,8 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[Any, ...], ArrayT], ArrayT], - ary: ArrayOrContainerT) -> ArrayOrContainerT: + f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT], + ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also takes in a traversal path to the leaf array. The traversal path argument is @@ -370,7 +381,7 @@ def rec_keyed_map_array_container( try: iterable = serialize_container(_ary) except NotAnArrayContainerError: - return f(keys, _ary) + return cast(ArrayOrContainerT, f(keys, cast(ArrayT, _ary))) else: return deserialize_container(_ary, [ (key, rec(keys + (key,), subary)) for key, subary in iterable @@ -409,7 +420,7 @@ def map_reduce_array_container( def multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], - *args: Any) -> "Array": + *args: Any) -> ArrayOrContainer: r"""Perform a map-reduce over multiple array containers. :param reduce_func: callable used to reduce over the components of any @@ -421,7 +432,9 @@ def multimap_reduce_array_container( """ # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier - def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: + def _reduce_wrapper( + ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]] + ) -> Array: return reduce_func([subary for _, subary in iterable]) return _multimap_array_container_impl( @@ -432,8 +445,8 @@ def multimap_reduce_array_container( def rec_map_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[[Any], Any], - ary: ArrayOrContainerT, - leaf_class: Optional[type] = None) -> "Array": + ary: ArrayOrContainer, + leaf_class: Optional[type] = None) -> ArrayOrContainer: """Perform a map-reduce over array containers recursively. :param reduce_func: callable used to reduce over the components of *ary* @@ -491,7 +504,7 @@ def rec_multimap_reduce_array_container( reduce_func: Callable[[Iterable[Any]], Any], map_func: Callable[..., Any], *args: Any, - leaf_class: Optional[type] = None) -> "Array": + leaf_class: Optional[type] = None) -> ArrayOrContainer: r"""Perform a map-reduce over multiple array containers recursively. :param reduce_func: callable used to reduce over the components of any @@ -509,7 +522,8 @@ def rec_multimap_reduce_array_container( """ # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier - def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: + def _reduce_wrapper( + ary: ArrayContainer, iterable: Iterable[Tuple[Any, Any]]) -> Any: return reduce_func([subary for _, subary in iterable]) return _multimap_array_container_impl( @@ -623,7 +637,7 @@ def with_array_context(ary: ArrayOrContainerT, # {{{ flatten / unflatten def flatten( - ary: ArrayOrContainerT, actx: ArrayContext, *, + ary: ArrayOrContainer, actx: ArrayContext, *, leaf_class: Optional[type] = None, ) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` @@ -646,32 +660,35 @@ def flatten( """ common_dtype = None - def _flatten(subary: ArrayOrContainerT) -> List[Any]: + def _flatten(subary: ArrayOrContainer) -> List[Array]: nonlocal common_dtype try: iterable = serialize_container(subary) except NotAnArrayContainerError: + subary_c = cast(Array, subary) + if common_dtype is None: - common_dtype = subary.dtype + common_dtype = subary_c.dtype - if subary.dtype != common_dtype: + if subary_c.dtype != common_dtype: raise ValueError("arrays in container have different dtypes: " - f"got {subary.dtype}, expected {common_dtype}") + f"got {subary_c.dtype}, expected {common_dtype}") try: - flat_subary = actx.np.ravel(subary, order="C") + flat_subary = actx.np.ravel(subary_c, order="C") except ValueError as exc: # NOTE: we can't do much if the array context fails to ravel, # since it is the one responsible for the actual memory layout - if hasattr(subary, "strides"): - strides_msg = f" and strides {subary.strides}" + if hasattr(subary_c, "strides"): + # Mypy has a point: nobody promised a strides attr. + strides_msg = f" and strides {subary_c.strides}" # type: ignore[attr-defined] # noqa: E501 else: strides_msg = "" raise NotImplementedError( f"'{type(actx).__name__}.np.ravel' failed to reshape " - f"an array with shape {subary.shape}{strides_msg}. " + f"an array with shape {subary_c.shape}{strides_msg}. " "This functionality needs to be implemented by the " "array context.") from exc @@ -683,7 +700,7 @@ def flatten( return result - def _flatten_without_leaf_class(subary: ArrayOrContainerT) -> Any: + def _flatten_without_leaf_class(subary: ArrayOrContainer) -> Any: result = _flatten(subary) if len(result) == 1: @@ -691,7 +708,7 @@ def flatten( else: return actx.np.concatenate(result) - def _flatten_with_leaf_class(subary: ArrayOrContainerT) -> Any: + def _flatten_with_leaf_class(subary: ArrayOrContainer) -> Any: if type(subary) is leaf_class: return _flatten_without_leaf_class(subary) @@ -731,46 +748,48 @@ def unflatten( offset = 0 common_dtype = None - def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT: + def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer: nonlocal offset, common_dtype try: iterable = serialize_container(template_subary) except NotAnArrayContainerError: + template_subary_c = cast(Array, template_subary) + # {{{ validate subary - if (offset + template_subary.size) > ary.size: + if (offset + template_subary_c.size) > ary.size: raise ValueError("'template' and 'ary' sizes do not match: " "'template' is too large") if strict: - if template_subary.dtype != ary.dtype: + if template_subary_c.dtype != ary.dtype: raise ValueError("'template' dtype does not match 'ary': " - f"got {template_subary.dtype}, expected {ary.dtype}") + f"got {template_subary_c.dtype}, expected {ary.dtype}") else: # NOTE: still require that *template* has a uniform dtype if common_dtype is None: - common_dtype = template_subary.dtype + common_dtype = template_subary_c.dtype else: - if common_dtype != template_subary.dtype: + if common_dtype != template_subary_c.dtype: raise ValueError("arrays in 'template' have different " - f"dtypes: got {template_subary.dtype}, but " + f"dtypes: got {template_subary_c.dtype}, but " f"expected {common_dtype}.") # }}} # {{{ reshape - flat_subary = ary[offset:offset + template_subary.size] + flat_subary = ary[offset:offset + template_subary_c.size] try: subary = actx.np.reshape(flat_subary, - template_subary.shape, order="C") + template_subary_c.shape, order="C") except ValueError as exc: # NOTE: we can't do much if the array context fails to reshape, # since it is the one responsible for the actual memory layout raise NotImplementedError( f"'{type(actx).__name__}.np.reshape' failed to reshape " - f"the flat array into shape {template_subary.shape}. " + f"the flat array into shape {template_subary_c.shape}. " "This functionality needs to be implemented by the " "array context.") from exc @@ -782,21 +801,23 @@ def unflatten( # Checking strides for 0 sized arrays is ill-defined # since they cannot be indexed if ( - template_subary.strides != subary.strides - and template_subary.size != 0 + # Mypy has a point: nobody promised a .strides attribute. + template_subary_c.strides != subary.strides # type: ignore[attr-defined] # noqa: E501 + and template_subary_c.size != 0 ): raise ValueError( - f"strides do not match template: got {subary.strides}, " - f"expected {template_subary.strides}") + # Mypy has a point: nobody promised a .strides attribute. + f"strides do not match template: got {subary.strides}, " # type: ignore[attr-defined] # noqa: E501 + f"expected {template_subary_c.strides}") # }}} - offset += template_subary.size + offset += template_subary_c.size return subary else: return deserialize_container(template_subary, [ - (key, _unflatten(isubary)) for key, isubary in iterable - ]) + (key, _unflatten(isubary)) for key, isubary in iterable + ]) if not isinstance(ary, actx.array_types): raise TypeError("'ary' does not have a type supported by the provided " @@ -813,11 +834,11 @@ def unflatten( raise ValueError("'template' and 'ary' sizes do not match: " "'ary' is too large") - return result + return cast(ArrayOrContainerT, result) def flat_size_and_dtype( - ary: ArrayOrContainerT) -> "Tuple[int, Optional[np.dtype[Any]]]": + ary: ArrayOrContainer) -> "Tuple[int, Optional[np.dtype[Any]]]": """ :returns: a tuple ``(size, dtype)`` that would be the length and :class:`numpy.dtype` of the one-dimensional array returned by @@ -825,20 +846,22 @@ def flat_size_and_dtype( """ common_dtype = None - def _flat_size(subary: ArrayOrContainerT) -> int: + def _flat_size(subary: ArrayOrContainer) -> int: nonlocal common_dtype try: iterable = serialize_container(subary) except NotAnArrayContainerError: + subary_c = cast(Array, subary) + if common_dtype is None: - common_dtype = subary.dtype + common_dtype = subary_c.dtype - if subary.dtype != common_dtype: + if subary_c.dtype != common_dtype: raise ValueError("arrays in container have different dtypes: " - f"got {subary.dtype}, expected {common_dtype}") + f"got {subary_c.dtype}, expected {common_dtype}") - return subary.size + return subary_c.size else: return sum(_flat_size(isubary) for _, isubary in iterable) @@ -851,15 +874,15 @@ def flat_size_and_dtype( # {{{ numpy conversion def from_numpy( - ary: Union[np.ndarray, _ScalarLike], - actx: ArrayContext) -> ArrayOrContainerT: + ary: Union[np.ndarray, ScalarLike], + actx: ArrayContext) -> ArrayOrContainerOrScalar: """Convert all :mod:`numpy` arrays in the :class:`~arraycontext.ArrayContainer` to the base array type of :class:`~arraycontext.ArrayContext`. The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`. """ - def _from_numpy_with_check(subary: Union[np.ndarray, _ScalarLike]) \ - -> ArrayOrContainerT: + def _from_numpy_with_check(subary: Union[np.ndarray, ScalarLike]) \ + -> ArrayOrContainerOrScalar: if isinstance(subary, np.ndarray) or np.isscalar(subary): return actx.from_numpy(subary) else: @@ -868,7 +891,7 @@ def from_numpy( return rec_map_array_container(_from_numpy_with_check, ary) -def to_numpy(ary: ArrayOrContainerT, actx: ArrayContext) -> Any: +def to_numpy(ary: ArrayOrContainer, actx: ArrayContext) -> Any: """Convert all arrays in the :class:`~arraycontext.ArrayContainer` to :mod:`numpy` using the provided :class:`~arraycontext.ArrayContext` *actx*. diff --git a/arraycontext/context.py b/arraycontext/context.py index b6278e2..74d5863 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -72,16 +72,43 @@ actual array contexts: an array expression that has been built up by the user (using, e.g. :func:`pytato.generate_loopy`). -The interface of an array context ---------------------------------- .. currentmodule:: arraycontext -.. autoclass:: Array -.. autoclass:: Scalar +The interface of an array context +--------------------------------- + .. autoclass:: ArrayContext + .. autofunction:: tag_axes +Types and Type Variables for Arrays and Containers +-------------------------------------------------- + +.. autoclass:: Array + +.. class:: ArrayT + + A type variable with a lower bound of :class:`Array`. + +.. class:: ScalarLike + + A type annotation for scalar types commonly usable with arrays. + +See also :class:`ArrayContainer` and :class:`ArrayOrContainerT`. + +.. class:: ArrayOrContainer + +.. class:: ArrayOrContainerT + + A type variable with a lower bound of :class:`ArrayOrContainer`. + +.. class:: ArrayOrContainerOrScalar + +.. class:: ArrayOrContainerOrScalarT + + A type variable with a lower bound of :class:`ArrayOrContainerOrScalar`. + Internal typing helpers (do not import) --------------------------------------- @@ -91,11 +118,21 @@ This is only here because the documentation tool wants it. .. class:: SelfType +Canonical locations for type annotations +---------------------------------------- + .. class:: ArrayT - A type variable, with a lower bound of :class:`Array`. -""" + :canonical: arraycontext.ArrayT + +.. class:: ArrayOrContainerT + + :canonical: arraycontext.ArrayOrContainerT +.. class:: ArrayOrContainerOrScalarT + + :canonical: arraycontext.ArrayOrContainerOrScalarT +""" __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -132,11 +169,12 @@ from pytools.tag import ToTagSetConvertible if TYPE_CHECKING: import loopy + from arraycontext.container import ArrayContainer # {{{ typing -_ScalarLike = Union[int, float, complex, np.generic] +ScalarLike = Union[int, float, complex, np.generic] try: from typing import Protocol @@ -154,6 +192,7 @@ class Array(Protocol): supported types see :attr:`ArrayContext.array_types`. .. attribute:: shape + .. attribute:: size .. attribute:: dtype """ @@ -162,28 +201,7 @@ class Array(Protocol): ... @property - def dtype(self) -> "np.dtype[Any]": - ... - - -ArrayT = TypeVar("ArrayT", bound=Array) - - -class Scalar(Protocol): - """A :class:`~typing.Protocol` for the scalar type supported by - :class:`ArrayContext`. - - In :mod:`numpy` terminology, this is just an array with a shape of ``()``. - - This is meant to aid in typing annotations. For a explicit list of - supported types see :attr:`ArrayContext.array_types`. - - .. attribute:: shape - .. attribute:: dtype - """ - - @property - def shape(self) -> Tuple[()]: + def size(self) -> int: ... @property @@ -191,6 +209,18 @@ class Scalar(Protocol): ... +# deprecated, use ScalarLike instead +Scalar = ScalarLike + + +ArrayT = TypeVar("ArrayT", bound=Array) +ArrayOrContainer = Union[Array, "ArrayContainer"] +ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer) +ArrayOrContainerOrScalar = Union[Array, "ArrayContainer", ScalarLike] +ArrayOrContainerOrScalarT = TypeVar( + "ArrayOrContainerOrScalarT", + bound=ArrayOrContainerOrScalar) + # }}} @@ -273,8 +303,8 @@ class ArrayContext(ABC): @abstractmethod def from_numpy(self, - array: Union["np.ndarray[Any, Any]", _ScalarLike] - ) -> Union[Array, _ScalarLike]: + array: Union["np.ndarray[Any, Any]", ScalarLike] + ) -> Union[Array, ScalarLike]: r""" :returns: the :class:`numpy.ndarray` *array* converted to the array context's array type. The returned array will be @@ -284,8 +314,8 @@ class ArrayContext(ABC): @abstractmethod def to_numpy(self, - array: Union[Array, _ScalarLike] - ) -> Union["np.ndarray[Any, Any]", _ScalarLike]: + array: Union[Array, ScalarLike] + ) -> Union["np.ndarray[Any, Any]", ScalarLike]: r""" :returns: *array*, an array recognized by the context, converted to a :class:`numpy.ndarray`. *array* must be @@ -293,6 +323,7 @@ class ArrayContext(ABC): """ pass + @abstractmethod def call_loopy(self, program: "loopy.TranslationUnit", **kwargs: Any) -> Dict[str, Array]: @@ -308,7 +339,7 @@ class ArrayContext(ABC): """ @abstractmethod - def freeze(self, array: Array) -> Array: + def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: """Return a version of the context-defined array *array* that is 'frozen', i.e. suitable for long-term storage and reuse. Frozen arrays do not support arithmetic. For example, in the context of @@ -324,7 +355,7 @@ class ArrayContext(ABC): """ @abstractmethod - def thaw(self, array: Array) -> Array: + def thaw(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: """Take a 'frozen' array and return a new array representing the data in *array* that is able to perform arithmetic and other operations, using the execution resources of this context. In the context of diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 154300a..81076ae 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -31,7 +31,7 @@ import numpy as np from typing import Union, Callable, Any from pytools.tag import ToTagSetConvertible -from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.context import ArrayContext, ScalarLike from arraycontext.container.traversal import (with_array_context, rec_map_array_container) @@ -70,7 +70,7 @@ class EagerJAXArrayContext(ArrayContext): import jax.numpy as jnp return jnp.zeros(shape=shape, dtype=dtype) - def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): + def from_numpy(self, array: Union[np.ndarray, ScalarLike]): import jax return jax.device_put(array) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index de0b27b..1467246 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -35,7 +35,7 @@ import numpy as np from pytools.tag import ToTagSetConvertible -from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.context import ArrayContext, ScalarLike from arraycontext.container.traversal import (rec_map_array_container, with_array_context) @@ -167,7 +167,7 @@ class PyOpenCLArrayContext(ArrayContext): allocator=self.allocator), axes=None, tags=frozenset()) - def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): + def from_numpy(self, array: Union[np.ndarray, ScalarLike]): import pyopencl.array as cl_array from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array return to_tagged_cl_array(cl_array diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 0970ffb..ed3bff8 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -43,7 +43,7 @@ THE SOFTWARE. """ import sys -from arraycontext.context import ArrayContext, _ScalarLike +from arraycontext.context import ArrayContext, ScalarLike from arraycontext.container.traversal import (rec_map_array_container, with_array_context) from arraycontext.metadata import NameHint @@ -237,7 +237,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): def clone(self): return type(self)(self.queue, self.allocator) - def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): + def from_numpy(self, array: Union[np.ndarray, ScalarLike]): import pytato as pt import pyopencl.array as cla cl_array = cla.to_device(self.queue, array) @@ -288,7 +288,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): import pyopencl.array as cla import loopy as lp - from arraycontext.container import ArrayT + from arraycontext.context import ArrayT from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.utils import (_normalize_pt_expr, get_cl_axes_from_pt_axes) @@ -524,7 +524,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): def clone(self): return type(self)() - def from_numpy(self, array: Union[np.ndarray, _ScalarLike]): + def from_numpy(self, array: Union[np.ndarray, ScalarLike]): import jax import pytato as pt return pt.make_data_wrapper(jax.device_put(array)) @@ -548,7 +548,7 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): import pytato as pt from jax.numpy import DeviceArray - from arraycontext.container import ArrayT + from arraycontext.context import ArrayT from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 46404e2..9e92adf 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -29,8 +29,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from arraycontext.container import (ArrayContainer, is_array_container_type, - ArrayT) +from arraycontext.context import ArrayT +from arraycontext.container import ArrayContainer, is_array_container_type from arraycontext.impl.pytato import (_BasePytatoArrayContext, PytatoJAXArrayContext, PytatoPyOpenCLArrayContext) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 928f446..e89663c 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -174,7 +174,8 @@ def _serialize_dof_container(ary: DOFArray): @deserialize_container.register(DOFArray) -def _deserialize_dof_container( +# https://github.com/python/mypy/issues/13040 +def _deserialize_dof_container( # type: ignore[misc] template, iterable): def _raise_index_inconsistency(i, stream_i): raise ValueError( @@ -189,7 +190,8 @@ def _deserialize_dof_container( @with_array_context.register(DOFArray) -def _with_actx_dofarray(ary, actx): +# https://github.com/python/mypy/issues/13040 +def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: ignore[misc] # noqa: E501 return type(ary)(actx, ary.data) # }}} @@ -1188,7 +1190,8 @@ class Velocity2D: @with_array_context.register(Velocity2D) -def _with_actx_velocity_2d(ary, actx): +# https://github.com/python/mypy/issues/13040 +def _with_actx_velocity_2d(ary, actx): # type: ignore[misc] return type(ary)(ary.u, ary.v, actx) -- GitLab