diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 0f349be59175f1c1606eb3fbd52c17b822410da2..ab8198ca962948946566f41e7889e2dcc399a3e2 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -255,6 +255,16 @@ def keyed_map_array_container( raise ValueError("Not an array-container, i.e. unknown key to pass.") +def _tuple_if_not_tuple(x: Any) -> Tuple[Union[str, int], ...]: + if not isinstance(x, tuple): + assert isinstance(x, (str, int)) + return x, + else: + assert all(isinstance(el, (str, int)) + for el in x) + return x + + def _keyed_map_array_container_impl( f: Callable[[Tuple[Union[str, int], ...], Any], Any], ary: ArrayContainerT, *, @@ -272,14 +282,6 @@ def _keyed_map_array_container_impl( if type(_ary) is leaf_cls: # type(ary) is never None return f(keys, _ary) elif is_array_container(_ary): - def _tuple_if_not_tuple(x: Any) -> Tuple[Union[str, int], ...]: - if not isinstance(x, tuple): - assert isinstance(x, (str, int)) - return x, - else: - assert all(isinstance(el, (str, int)) - for el in x) - return x return deserialize_container(_ary, [ (key, frec(keys+_tuple_if_not_tuple(key), subary))