diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index b29dc86df66461ab9ca214e347954f3742baae31..2469334de47c47718d3c6d8ce1063de88362232c 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -70,7 +70,7 @@ import numpy as np from arraycontext.context import ArrayContext, DeviceArray from arraycontext.container import ( - ContainerT, ArrayOrContainerT, NotAnArrayContainerError, + ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -327,8 +327,11 @@ def multimapped_over_array_containers( # {{{ keyed array container traversal -def keyed_map_array_container(f: Callable[[Any, Any], Any], - ary: ArrayOrContainerT) -> ArrayOrContainerT: +def keyed_map_array_container( + f: Callable[ + [Any, ArrayOrContainerT], + ArrayOrContainerT], + ary: ArrayOrContainerT) -> ArrayOrContainerT: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`map_array_container`, but *f* also takes an @@ -350,8 +353,9 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], ]) -def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], - ary: ArrayOrContainerT) -> ArrayOrContainerT: +def rec_keyed_map_array_container( + f: Callable[[Tuple[Any, ...], ArrayT], ArrayT], + ary: ArrayOrContainerT) -> ArrayOrContainerT: """ Works similarly to :func:`rec_map_array_container`, except that *f* also takes in a traversal path to the leaf array. The traversal path argument is