From 50511fe79e3a23407ee362f6fa098945a6132dff Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Wed, 11 Jan 2023 17:11:30 +0200 Subject: [PATCH] add a dumb container tree stringifier --- arraycontext/__init__.py | 5 ++-- arraycontext/container/traversal.py | 29 +++++++++++++++++++ test/test_utils.py | 45 +++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 00e542d..b01b991 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -44,8 +44,8 @@ from .container.traversal import ( multimap_array_container, multimap_reduce_array_container, multimapped_over_array_containers, outer, rec_map_array_container, rec_map_reduce_array_container, rec_multimap_array_container, - rec_multimap_reduce_array_container, thaw, to_numpy, unflatten, - with_array_context) + rec_multimap_reduce_array_container, stringify_array_container_tree, thaw, + to_numpy, unflatten, with_array_context) from .context import ( Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike, @@ -85,6 +85,7 @@ __all__ = ( "with_container_arithmetic", "dataclass_array_container", + "stringify_array_container_tree", "map_array_container", "multimap_array_container", "rec_map_array_container", "rec_multimap_array_container", "mapped_over_array_containers", diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 048f6a1..b59fe79 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -13,6 +13,8 @@ .. autofunction:: rec_map_reduce_array_container .. autofunction:: rec_multimap_reduce_array_container +.. autofunction:: stringify_array_container_tree + Traversing decorators ~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: mapped_over_array_containers @@ -224,6 +226,33 @@ def _multimap_array_container_impl( # {{{ array container traversal +def stringify_array_container_tree(ary: ArrayOrContainer) -> str: + """ + :returns: a string for an ASCII tree representation of the array container, + similar to `asciitree `__. + """ + def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None: + try: + iterable = serialize_container(ary_) + except NotAnArrayContainerError: + pass + else: + for key, subary in iterable: + key = f"{key} ({type(subary).__name__})" + if level == 0: + indent = "" + else: + indent = f" | {' ' * 4 * (level - 1)}" + + lines.append(f"{indent} +-- {key}") + rec(lines, subary, level + 1) + + lines = [f"root ({type(ary).__name__})"] + rec(lines, ary, 0) + + return "\n".join(lines) + + def map_array_container( f: Callable[[Any], Any], ary: ArrayOrContainer) -> ArrayOrContainer: diff --git a/test/test_utils.py b/test/test_utils.py index 94f6b0a..4bb49c8 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -144,6 +144,51 @@ def test_dataclass_container_unions() -> None: # }}} +# {{{ test_stringify_array_container_tree + + +def test_stringify_array_container_tree() -> None: + from dataclasses import dataclass + + from arraycontext import ( + Array, dataclass_array_container, stringify_array_container_tree) + + @dataclass_array_container + @dataclass(frozen=True) + class ArrayWrapper: + ary: Array + + @dataclass_array_container + @dataclass(frozen=True) + class SomeContainer: + points: Array + radius: float + centers: ArrayWrapper + + @dataclass_array_container + @dataclass(frozen=True) + class SomeOtherContainer: + disk: SomeContainer + circle: SomeContainer + has_disk: bool + norm_type: str + extent: float + + rng = np.random.default_rng(seed=42) + a = ArrayWrapper(ary=rng.random(10)) + d = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + c = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + ary = SomeOtherContainer( + disk=d, circle=c, + has_disk=True, + norm_type="l2", + extent=1) + + logger.info("\n%s", stringify_array_container_tree(ary)) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: -- GitLab