diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 0bd0a1daae7c6c093e661c93152956e3b2df539b..c3653bbf4f47f694f7231e572ce38107be1a1a5e 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -256,45 +256,28 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any], raise ValueError("Not an array-container, i.e. unknown key to pass.") -def _keyed_map_array_container_impl(f: Callable[[Tuple[Any, ...], Any], Any], - ary: ArrayContainerT, - *, - leaf_cls: Optional[type] = None, - recursive: bool = False) -> ArrayContainerT: - """Helper for :func:`rec_keyed_map_array_container`. - - :param leaf_cls: class on which we call *f* directly. This is mostly - useful in the recursive setting, where it can stop the recursion on - specific container classes. By default, the recursion is stopped when - a non-:class:`ArrayContainer` class is encountered. +def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], + ary: ArrayContainerT) -> ArrayContainerT: """ + Works similarly to :func:`rec_map_array_container`, except that *f* also + takes in a traversal path to the leaf array. The traversal path argument is + passed in as a tuple of identifiers of the arrays traversed before reaching + the current array. + """ + def rec(keys: Tuple[Union[str, int], ...], _ary: ArrayContainerT) -> ArrayContainerT: - if type(_ary) is leaf_cls: # type(ary) is never None - return f(keys, _ary) - elif is_array_container(_ary): + if is_array_container(_ary): return deserialize_container(_ary, [ - (key, frec(keys+(key,), subary)) + (key, rec(keys+(key,), subary)) for key, subary in serialize_container(_ary) ]) else: return f(keys, _ary) - frec = rec if recursive else f return rec((), ary) - -def rec_keyed_map_array_container(f: Callable[[Tuple[Any, ...], Any], Any], - ary: ArrayContainerT) -> ArrayContainerT: - """ - Works similarly to :func:`rec_map_array_container`, except that *f* also - takes in a traversal path to the leaf array. The traversal path argument is - passed in as a tuple of identifiers of the arrays traversed before reaching - the current array. - """ - return _keyed_map_array_container_impl(f, ary, recursive=True) - # }}}