From e30e01237dd1bbcc30b938335ee0840e6decd84d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 14:17:50 -0500 Subject: [PATCH] Container serialization: iterable -> sequence, plus type aliases --- arraycontext/__init__.py | 6 +++- arraycontext/container/__init__.py | 50 ++++++++++++++++++++++------- arraycontext/container/traversal.py | 10 +++--- pyproject.toml | 3 ++ 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 2f2640d..e8e6e9f 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -32,6 +32,8 @@ from .container import ( ArrayContainer, ArrayContainerT, NotAnArrayContainerError, + SerializationKey, + SerializedContainer, deserialize_container, get_container_context_opt, get_container_context_recursively, @@ -113,6 +115,8 @@ __all__ = ( "PytestPyOpenCLArrayContextFactory", "Scalar", "ScalarLike", + "SerializationKey", + "SerializedContainer", "dataclass_array_container", "deserialize_container", "flat_size_and_dtype", @@ -148,7 +152,7 @@ __all__ = ( "with_array_context", "with_container_arithmetic", "with_container_arithmetic" - ) +) # {{{ deprecation handling diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 53506a0..38a2341 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -12,6 +12,9 @@ Serialization/deserialization ----------------------------- + +.. autoclass:: SerializationKey +.. autoclass:: SerializedContainer .. autofunction:: is_array_container_type .. autofunction:: serialize_container .. autofunction:: deserialize_container @@ -39,6 +42,14 @@ Canonical locations for type annotations .. class:: ArrayOrContainerT :canonical: arraycontext.ArrayOrContainerT + +.. class:: SerializationKey + + :canonical: arraycontext.SerializationKey + +.. class:: SerializedContainer + + :canonical: arraycontext.SerializedContainer """ from __future__ import annotations @@ -69,12 +80,23 @@ THE SOFTWARE. """ from functools import singledispatch -from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Hashable, + Iterable, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, +) # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy import numpy as np +from typing_extensions import TypeAlias from arraycontext.context import ArrayContext @@ -142,23 +164,27 @@ class NotAnArrayContainerError(TypeError): """:class:`TypeError` subclass raised when an array container is expected.""" +SerializationKey: TypeAlias = Hashable +SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]] + + @singledispatch def serialize_container( - ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]: - r"""Serialize the array container into an iterable over its components. + ary: ArrayContainer) -> SerializedContainer: + r"""Serialize the array container into a sequence over its components. The order of the components and their identifiers are entirely under the control of the container class. However, the order is required to be deterministic, i.e. two calls to :func:`serialize_container` on array containers of the same types with the same number of - sub-arrays must result in an iterable with the keys in the same + sub-arrays must result in a sequence with the keys in the same order. If *ary* is mutable, the serialization function is not required to ensure that the serialization result reflects the array state at the time of the call to :func:`serialize_container`. - :returns: an :class:`Iterable` of 2-tuples where the first + :returns: a :class:`Sequence` of 2-tuples where the first entry is an identifier for the component and the second entry is an array-like component of the :class:`ArrayContainer`. Components can themselves be :class:`ArrayContainer`\ s, allowing @@ -172,13 +198,13 @@ def serialize_container( @singledispatch def deserialize_container( template: ArrayContainerT, - iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT: - """Deserialize an iterable into an array container. + serialized: SerializedContainer) -> ArrayContainerT: + """Deserialize a sequence into an array container following a *template*. :param template: an instance of an existing object that can be used to aid in the deserialization. For a similar choice see :attr:`~numpy.class.__array_finalize__`. - :param iterable: an iterable that mirrors the output of + :param serialized: a sequence that mirrors the output of :meth:`serialize_container`. """ raise NotAnArrayContainerError( @@ -242,7 +268,7 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]: @serialize_container.register(np.ndarray) def _serialize_ndarray_container( - ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]: + ary: numpy.ndarray) -> SerializedContainer: if ary.dtype.char != "O": raise NotAnArrayContainerError( f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'") @@ -256,20 +282,20 @@ def _serialize_ndarray_container( for j in range(ary.shape[1]) ] else: - return np.ndenumerate(ary) + return list(np.ndenumerate(ary)) @deserialize_container.register(np.ndarray) # https://github.com/python/mypy/issues/13040 def _deserialize_ndarray_container( # type: ignore[misc] template: numpy.ndarray, - iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray: + serialized: SerializedContainer) -> numpy.ndarray: # disallow subclasses assert type(template) is np.ndarray assert template.dtype.char == "O" result = type(template)(template.shape, dtype=object) - for i, subary in iterable: + for i, subary in serialized: result[i] = subary return result diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 4a60a8f..31b3bcf 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -77,6 +77,7 @@ import numpy as np from arraycontext.container import ( ArrayContainer, NotAnArrayContainerError, + SerializationKey, deserialize_container, get_container_context_recursively_opt, serialize_container, @@ -373,12 +374,9 @@ def multimapped_over_array_containers( # {{{ keyed array container traversal -KeyType = Any - - def keyed_map_array_container( f: Callable[ - [KeyType, ArrayOrContainer], + [SerializationKey, ArrayOrContainer], ArrayOrContainer], ary: ArrayOrContainer) -> ArrayOrContainer: r"""Applies *f* to all components of an :class:`ArrayContainer`. @@ -403,7 +401,7 @@ def keyed_map_array_container( def rec_keyed_map_array_container( - f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT], + f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT], ary: ArrayOrContainer) -> ArrayOrContainer: """ Works similarly to :func:`rec_map_array_container`, except that *f* also @@ -412,7 +410,7 @@ def rec_keyed_map_array_container( the current array. """ - def rec(keys: Tuple[Union[str, int], ...], + def rec(keys: Tuple[SerializationKey, ...], _ary: ArrayOrContainerT) -> ArrayOrContainerT: try: iterable = serialize_container(_ary) diff --git a/pyproject.toml b/pyproject.toml index ca64c70..d971ae2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "immutabledict>=4.1", "numpy", "pytools>=2024.1.3", + + # for TypeAlias + "typing-extensions>=4; python_version<'3.10'", ] [project.optional-dependencies] -- GitLab