From 07d3d4b6b804e201000b37d9a37184bdbdafea7e Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni <kaushikcfd@gmail.com>
Date: Sun, 20 Jun 2021 17:05:53 -0500
Subject: [PATCH] type of keys in rec_keyed_map_array_container is more relaxed

---
 arraycontext/container/traversal.py | 32 +++++++++--------------------
 1 file changed, 10 insertions(+), 22 deletions(-)

diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py
index ab8198c..1dc0d3b 100644
--- a/arraycontext/container/traversal.py
+++ b/arraycontext/container/traversal.py
@@ -234,9 +234,8 @@ def multimapped_over_array_containers(
     return wrapper
 
 
-def keyed_map_array_container(
-        f: Callable[[Union[str, int], Any], Any],
-        ary: ArrayContainerT) -> ArrayContainerT:
+def keyed_map_array_container(f: Callable[[Any, Any], Any],
+                              ary: ArrayContainerT) -> ArrayContainerT:
     r"""Applies *f* to all components of an :class:`ArrayContainer`.
 
     Works similarly to :func:`map_array_container`, but *f* also takes an
@@ -255,21 +254,11 @@ def keyed_map_array_container(
         raise ValueError("Not an array-container, i.e. unknown key to pass.")
 
 
-def _tuple_if_not_tuple(x: Any) -> Tuple[Union[str, int], ...]:
-    if not isinstance(x, tuple):
-        assert isinstance(x, (str, int))
-        return x,
-    else:
-        assert all(isinstance(el, (str, int))
-                   for el in x)
-        return x
-
-
-def _keyed_map_array_container_impl(
-        f: Callable[[Tuple[Union[str, int], ...], Any], Any],
-        ary: ArrayContainerT, *,
-        leaf_cls: Optional[type] = None,
-        recursive: bool = False) -> ArrayContainerT:
+def _keyed_map_array_container_impl(f: Callable[[Tuple[Any, ...], Any], Any],
+                                    ary: ArrayContainerT,
+                                    *,
+                                    leaf_cls: Optional[type] = None,
+                                    recursive: bool = False) -> ArrayContainerT:
     """Helper for :func:`rec_keyed_map_array_container`.
 
     :param leaf_cls: class on which we call *f* directly. This is mostly
@@ -284,7 +273,7 @@ def _keyed_map_array_container_impl(
         elif is_array_container(_ary):
 
             return deserialize_container(_ary, [
-                    (key, frec(keys+_tuple_if_not_tuple(key), subary))
+                    (key, frec(keys+(key,), subary))
                     for key, subary in serialize_container(_ary)
                     ])
         else:
@@ -294,9 +283,8 @@ def _keyed_map_array_container_impl(
     return rec((), ary)
 
 
-def rec_keyed_map_array_container(
-        f: Callable[[Tuple[Union[str, int], ...], Any], Any],
-        ary: ArrayContainerT) -> ArrayContainerT:
+def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any],
+                                  ary: ArrayContainerT) -> ArrayContainerT:
     """
     Works similarly to :func:`rec_map_array_container`, except that *f* also
     takes in a traversal path to the leaf array. The traversal path argument is
-- 
GitLab