From 427d25c7f015839dd910c9d406640f0bdba8d426 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner <inform@tiker.net> Date: Tue, 11 Jan 2022 12:53:29 -0600 Subject: [PATCH] Tighten type info on {rec_,}keyed_map_array_container --- arraycontext/container/traversal.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index b29dc86..2469334 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -70,7 +70,7 @@ import numpy as np from arraycontext.context import ArrayContext, DeviceArray from arraycontext.container import ( - ContainerT, ArrayOrContainerT, NotAnArrayContainerError, + ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError, serialize_container, deserialize_container) @@ -327,8 +327,11 @@ def multimapped_over_array_containers( # {{{ keyed array container traversal -def keyed_map_array_container(f: Callable[[Any, Any], Any], - ary: ArrayOrContainerT) -> ArrayOrContainerT: +def keyed_map_array_container( + f: Callable[ + [Any, ArrayOrContainerT], + ArrayOrContainerT], + ary: ArrayOrContainerT) -> ArrayOrContainerT: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`map_array_container`, but *f* also takes an @@ -350,8 +353,9 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], ]) -def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], - ary: ArrayOrContainerT) -> ArrayOrContainerT: +def rec_keyed_map_array_container( + f: Callable[[Tuple[Any, ...], ArrayT], ArrayT], + ary: ArrayOrContainerT) -> ArrayOrContainerT: """ Works similarly to :func:`rec_map_array_container`, except that *f* also takes in a traversal path to the leaf array. The traversal path argument is -- GitLab