From 427d25c7f015839dd910c9d406640f0bdba8d426 Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Tue, 11 Jan 2022 12:53:29 -0600
Subject: [PATCH] Tighten type info on {rec_,}keyed_map_array_container

---
 arraycontext/container/traversal.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index b29dc86..2469334 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -70,7 +70,7 @@ import numpy as np
 
 from arraycontext.context import ArrayContext, DeviceArray
 from arraycontext.container import (
-        ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
+        ArrayT, ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
         serialize_container, deserialize_container)
 
 
@@ -327,8 +327,11 @@ def multimapped_over_array_containers(
 
 # {{{ keyed array container traversal
 
-def keyed_map_array_container(f: Callable[[Any, Any], Any],
-                              ary: ArrayOrContainerT) -> ArrayOrContainerT:
+def keyed_map_array_container(
+        f: Callable[
+            [Any, ArrayOrContainerT],
+            ArrayOrContainerT],
+        ary: ArrayOrContainerT) -> ArrayOrContainerT:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`map_array_container`, but *f* also takes an
@@ -350,8 +353,9 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any],
             ])
 
 
-def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
-                                  ary: ArrayOrContainerT) -> ArrayOrContainerT:
+def rec_keyed_map_array_container(
+        f: Callable[[Tuple[Any, ...], ArrayT], ArrayT],
+        ary: ArrayOrContainerT) -> ArrayOrContainerT:
     """
     Works similarly to :func:`rec_map_array_container`, except that *f* also
     takes in a traversal path to the leaf array. The traversal path argument is
-- 
GitLab