diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ab8198ca962948946566f41e7889e2dcc399a3e2..1dc0d3b1193bea0df29baba96dada561f10f10e9 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -234,9 +234,8 @@ def multimapped_over_array_containers( return wrapper -def keyed_map_array_container( - f: Callable[[Union[str, int], Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: +def keyed_map_array_container(f: Callable[[Any, Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: r"""Applies *f* to all components of an :class:`ArrayContainer`. Works similarly to :func:`map_array_container`, but *f* also takes an @@ -255,21 +254,11 @@ def keyed_map_array_container( raise ValueError("Not an array-container, i.e. unknown key to pass.") -def _tuple_if_not_tuple(x: Any) -> Tuple[Union[str, int], ...]: - if not isinstance(x, tuple): - assert isinstance(x, (str, int)) - return x, - else: - assert all(isinstance(el, (str, int)) - for el in x) - return x - - -def _keyed_map_array_container_impl( - f: Callable[[Tuple[Union[str, int], ...], Any], Any], - ary: ArrayContainerT, *, - leaf_cls: Optional[type] = None, - recursive: bool = False) -> ArrayContainerT: +def _keyed_map_array_container_impl(f: Callable[[Tuple[Any, ...], Any], Any], + ary: ArrayContainerT, + *, + leaf_cls: Optional[type] = None, + recursive: bool = False) -> ArrayContainerT: """Helper for :func:`rec_keyed_map_array_container`. :param leaf_cls: class on which we call *f* directly. This is mostly @@ -284,7 +273,7 @@ def _keyed_map_array_container_impl( elif is_array_container(_ary): return deserialize_container(_ary, [ - (key, frec(keys+_tuple_if_not_tuple(key), subary)) + (key, frec(keys+(key,), subary)) for key, subary in serialize_container(_ary) ]) else: @@ -294,9 +283,8 @@ def _keyed_map_array_container_impl( return rec((), ary) -def rec_keyed_map_array_container( - f: Callable[[Tuple[Union[str, int], ...], Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: +def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: """ Works similarly to :func:`rec_map_array_container`, except that *f* also takes in a traversal path to the leaf array. The traversal path argument is