diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 00e542dec0502c3a6754a18c78978940052c0f9c..b01b9917864052ca11f754dbb53bef30130da402 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -44,8 +44,8 @@ from .container.traversal import ( multimap_array_container, multimap_reduce_array_container, multimapped_over_array_containers, outer, rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, - rec_multimap_reduce_array_container, thaw, to_numpy, unflatten, - with_array_context) + rec_multimap_reduce_array_container, stringify_array_container_tree, thaw, + to_numpy, unflatten, with_array_context) from .context import ( Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike, @@ -85,6 +85,7 @@ __all__ = ( "with_container_arithmetic", "dataclass_array_container", + "stringify_array_container_tree", "map_array_container", "multimap_array_container", "rec_map_array_container", "rec_multimap_array_container", "mapped_over_array_containers", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 048f6a19e9e6e88484edc80ae48ee610337d9d73..b59fe79451e4f0f702c61dc65bc044653232b4e7 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -13,6 +13,8 @@ .. autofunction:: rec_map_reduce_array_container .. autofunction:: rec_multimap_reduce_array_container +.. autofunction:: stringify_array_container_tree + Traversing decorators ~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: mapped_over_array_containers @@ -224,6 +226,33 @@ def _multimap_array_container_impl( # {{{ array container traversal +def stringify_array_container_tree(ary: ArrayOrContainer) -> str: + """ + :returns: a string for an ASCII tree representation of the array container, + similar to `asciitree `__. + """ + def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None: + try: + iterable = serialize_container(ary_) + except NotAnArrayContainerError: + pass + else: + for key, subary in iterable: + key = f"{key} ({type(subary).__name__})" + if level == 0: + indent = "" + else: + indent = f" | {' ' * 4 * (level - 1)}" + + lines.append(f"{indent} +-- {key}") + rec(lines, subary, level + 1) + + lines = [f"root ({type(ary).__name__})"] + rec(lines, ary, 0) + + return "\n".join(lines) + + def map_array_container( f: Callable[[Any], Any], ary: ArrayOrContainer) -> ArrayOrContainer: diff --git a/test/test_utils.py b/test/test_utils.py index 94f6b0a31888564658612212e0daae15dd065ae0..4bb49c87ec183487ace560a7058ee53f5ae1eea4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -144,6 +144,51 @@ def test_dataclass_container_unions() -> None: # }}} +# {{{ test_stringify_array_container_tree + + +def test_stringify_array_container_tree() -> None: + from dataclasses import dataclass + + from arraycontext import ( + Array, dataclass_array_container, stringify_array_container_tree) + + @dataclass_array_container + @dataclass(frozen=True) + class ArrayWrapper: + ary: Array + + @dataclass_array_container + @dataclass(frozen=True) + class SomeContainer: + points: Array + radius: float + centers: ArrayWrapper + + @dataclass_array_container + @dataclass(frozen=True) + class SomeOtherContainer: + disk: SomeContainer + circle: SomeContainer + has_disk: bool + norm_type: str + extent: float + + rng = np.random.default_rng(seed=42) + a = ArrayWrapper(ary=rng.random(10)) + d = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + c = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + ary = SomeOtherContainer( + disk=d, circle=c, + has_disk=True, + norm_type="l2", + extent=1) + + logger.info("\n%s", stringify_array_container_tree(ary)) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: