diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 2f2640dc9ecfbfc3f52d0c3e18f59a5c7cb00eec..e8e6e9f345888d19d830bc933e3d2a94206f32c3 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -32,6 +32,8 @@ from .container import ( ArrayContainer, ArrayContainerT, NotAnArrayContainerError, + SerializationKey, + SerializedContainer, deserialize_container, get_container_context_opt, get_container_context_recursively, @@ -113,6 +115,8 @@ __all__ = ( "PytestPyOpenCLArrayContextFactory", "Scalar", "ScalarLike", + "SerializationKey", + "SerializedContainer", "dataclass_array_container", "deserialize_container", "flat_size_and_dtype", @@ -148,7 +152,7 @@ __all__ = ( "with_array_context", "with_container_arithmetic", "with_container_arithmetic" - ) +) # {{{ deprecation handling diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 53506a0fc60b7ab95634a2c269fd19ed964f9f83..38a23412561270f5854f9d08ef0a9003e98f5bcb 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -12,6 +12,9 @@ Serialization/deserialization ----------------------------- + +.. autoclass:: SerializationKey +.. autoclass:: SerializedContainer .. autofunction:: is_array_container_type .. autofunction:: serialize_container .. autofunction:: deserialize_container @@ -39,6 +42,14 @@ Canonical locations for type annotations .. class:: ArrayOrContainerT :canonical: arraycontext.ArrayOrContainerT + +.. class:: SerializationKey + + :canonical: arraycontext.SerializationKey + +.. class:: SerializedContainer + + :canonical: arraycontext.SerializedContainer """ from __future__ import annotations @@ -69,12 +80,23 @@ THE SOFTWARE. """ from functools import singledispatch -from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + Iterable, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, +) # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy import numpy as np +from typing_extensions import TypeAlias from arraycontext.context import ArrayContext @@ -142,23 +164,27 @@ class NotAnArrayContainerError(TypeError): """:class:`TypeError` subclass raised when an array container is expected.""" +SerializationKey: TypeAlias = Hashable +SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]] + + @singledispatch def serialize_container( - ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]: - r"""Serialize the array container into an iterable over its components. + ary: ArrayContainer) -> SerializedContainer: + r"""Serialize the array container into a sequence over its components. The order of the components and their identifiers are entirely under the control of the container class. However, the order is required to be deterministic, i.e. two calls to :func:`serialize_container` on array containers of the same types with the same number of - sub-arrays must result in an iterable with the keys in the same + sub-arrays must result in a sequence with the keys in the same order. If *ary* is mutable, the serialization function is not required to ensure that the serialization result reflects the array state at the time of the call to :func:`serialize_container`. - :returns: an :class:`Iterable` of 2-tuples where the first + :returns: a :class:`Sequence` of 2-tuples where the first entry is an identifier for the component and the second entry is an array-like component of the :class:`ArrayContainer`. Components can themselves be :class:`ArrayContainer`\ s, allowing @@ -172,13 +198,13 @@ def serialize_container( @singledispatch def deserialize_container( template: ArrayContainerT, - iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT: - """Deserialize an iterable into an array container. + serialized: SerializedContainer) -> ArrayContainerT: + """Deserialize a sequence into an array container following a *template*. :param template: an instance of an existing object that can be used to aid in the deserialization. For a similar choice see :attr:`~numpy.class.__array_finalize__`. - :param iterable: an iterable that mirrors the output of + :param serialized: a sequence that mirrors the output of :meth:`serialize_container`. """ raise NotAnArrayContainerError( @@ -242,7 +268,7 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: @serialize_container.register(np.ndarray) def _serialize_ndarray_container( - ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]: + ary: numpy.ndarray) -> SerializedContainer: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") @@ -256,20 +282,20 @@ def _serialize_ndarray_container( for j in range(ary.shape[1]) ] else: - return np.ndenumerate(ary) + return list(np.ndenumerate(ary)) @deserialize_container.register(np.ndarray) # https://github.com/python/mypy/issues/13040 def _deserialize_ndarray_container( # type: ignore[misc] template: numpy.ndarray, - iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray: + serialized: SerializedContainer) -> numpy.ndarray: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" result = type(template)(template.shape, dtype=object) - for i, subary in iterable: + for i, subary in serialized: result[i] = subary return result diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 4a60a8f95a83d88cea3df8b11ff952354abe4f2d..31b3bcf5e145c4ed8bfe11c70c41698369294f6c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -77,6 +77,7 @@ import numpy as np from arraycontext.container import ( ArrayContainer, NotAnArrayContainerError, + SerializationKey, deserialize_container, get_container_context_recursively_opt, serialize_container, @@ -373,12 +374,9 @@ def multimapped_over_array_containers( # {{{ keyed array container traversal -KeyType = Any - - def keyed_map_array_container( f: Callable[ - [KeyType, ArrayOrContainer], + [SerializationKey, ArrayOrContainer], ArrayOrContainer], ary: ArrayOrContainer) -> ArrayOrContainer: r"""Applies *f* to all components of an :class:`ArrayContainer`. @@ -403,7 +401,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT], + f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -412,7 +410,7 @@ def rec_keyed_map_array_container( the current array. """ - def rec(keys: Tuple[Union[str, int], ...], + def rec(keys: Tuple[SerializationKey, ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) diff --git a/pyproject.toml b/pyproject.toml index ca64c70a9b3f5097a422c3d703c001a7a240d3e2..d971ae20b1cbef3139734109c1054cfa3e8221b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", + + # for TypeAlias + "typing-extensions>=4; python_version<'3.10'", ] [project.optional-dependencies]