From 8902f9fc39eb4dde14b7cfae5d23f412b5e2f228 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 22 May 2023 17:10:29 -0500 Subject: [PATCH] split visualization.py Use separate file for emitting dot and emitting asciidag. --- pytato/visualization/__init__.py | 16 +++ pytato/visualization/ascii.py | 124 ++++++++++++++++++ .../dot.py} | 89 ------------- 3 files changed, 140 insertions(+), 89 deletions(-) create mode 100644 pytato/visualization/__init__.py create mode 100644 pytato/visualization/ascii.py rename pytato/{visualization.py => visualization/dot.py} (86%) diff --git a/pytato/visualization/__init__.py b/pytato/visualization/__init__.py new file mode 100644 index 0000000..af9220a --- /dev/null +++ b/pytato/visualization/__init__.py @@ -0,0 +1,16 @@ +""" +.. currentmodule:: pytato + +.. automodule:: pytato.visualization.dot +.. automodule:: pytato.visualization.ascii +""" + +from .dot import get_dot_graph, show_dot_graph, get_dot_graph_from_partition +from .ascii import get_ascii_graph, show_ascii_graph + + +__all__ = [ + "get_dot_graph", "show_dot_graph", "get_dot_graph_from_partition", + + "get_ascii_graph", "show_ascii_graph", +] diff --git a/pytato/visualization/ascii.py b/pytato/visualization/ascii.py new file mode 100644 index 0000000..e417a9b --- /dev/null +++ b/pytato/visualization/ascii.py @@ -0,0 +1,124 @@ +""" +.. currentmodule:: pytato + +.. autofunction:: get_ascii_graph +.. autofunction:: show_ascii_graph +""" +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Union, List, Dict +from pytato.transform import ArrayOrNames +from pytato.array import Array, DictOfNamedArrays, InputArgumentBase +from pytato.visualization.dot import ArrayToDotNodeInfoMapper +from pytato.codegen import normalize_outputs +from pytools import UniqueNameGenerator + + +# {{{ Show ASCII representation of DAG + +def get_ascii_graph(result: Union[Array, DictOfNamedArrays], + use_color: bool = True) -> str: + """Return a string representing the computation of *result* + using the `asciidag `_ package. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`). + :arg use_color: Colorized output + """ + outputs: DictOfNamedArrays = normalize_outputs(result) + del result + + mapper = ArrayToDotNodeInfoMapper() + for elem in outputs._data.values(): + mapper(elem) + + nodes = mapper.nodes + + input_arrays: List[Array] = [] + internal_arrays: List[ArrayOrNames] = [] + array_to_id: Dict[ArrayOrNames, str] = {} + + id_gen = UniqueNameGenerator() + for array in nodes: + array_to_id[array] = id_gen("array") + if isinstance(array, InputArgumentBase): + input_arrays.append(array) + else: + internal_arrays.append(array) + + # Since 'asciidag' prints the DAG from top to bottom (ie, with the inputs + # at the bottom), we need to invert our representation of it, that is, the + # 'parents' constructor argument to Node() actually means 'children'. + from asciidag.node import Node # type: ignore[import] + asciidag_nodes: Dict[ArrayOrNames, Node] = {} + + from collections import defaultdict + asciidag_edges: Dict[ArrayOrNames, List[ArrayOrNames]] = defaultdict(list) + + # Reverse edge directions + for array in internal_arrays: + for _, v in nodes[array].edges.items(): + asciidag_edges[v].append(array) + + # Add the internal arrays in reversed order + for array in internal_arrays[::-1]: + ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] + + if array == internal_arrays[-1]: + ary_edges.append(Node("Outputs")) + + asciidag_nodes[array] = Node(f"{nodes[array].title}", + parents=ary_edges) + + # Add the input arrays last since they have no predecessors + for array in input_arrays: + ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] + asciidag_nodes[array] = Node(f"{nodes[array].title}", parents=ary_edges) + + input_node = Node("Inputs", parents=[asciidag_nodes[v] for v in input_arrays]) + + from asciidag.graph import Graph # type: ignore[import] + from io import StringIO + + f = StringIO() + graph = Graph(fh=f, use_color=use_color) + + graph.show_nodes([input_node]) + + # Get the graph and remove trailing whitespace + res = "\n".join([s.rstrip() for s in f.getvalue().split("\n")]) + + return res + + +def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None: + """Print a graph representing the computation of *result* to stdout using the + `asciidag `_ package. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. + """ + + print(get_ascii_graph(result, use_color=True)) diff --git a/pytato/visualization.py b/pytato/visualization/dot.py similarity index 86% rename from pytato/visualization.py rename to pytato/visualization/dot.py index 803ceb0..6f302fa 100644 --- a/pytato/visualization.py +++ b/pytato/visualization/dot.py @@ -59,8 +59,6 @@ __doc__ = """ .. autofunction:: get_dot_graph .. autofunction:: get_dot_graph_from_partition .. autofunction:: show_dot_graph -.. autofunction:: get_ascii_graph -.. autofunction:: show_ascii_graph """ @@ -553,91 +551,4 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, from pytools.graphviz import show_dot show_dot(dot_code, **kwargs) - -# {{{ Show ASCII representation of DAG - -def get_ascii_graph(result: Union[Array, DictOfNamedArrays], - use_color: bool = True) -> str: - """Return a string representing the computation of *result* - using the `asciidag `_ package. - - :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`). - :arg use_color: Colorized output - """ - outputs: DictOfNamedArrays = normalize_outputs(result) - del result - - mapper = ArrayToDotNodeInfoMapper() - for elem in outputs._data.values(): - mapper(elem) - - nodes = mapper.nodes - - input_arrays: List[Array] = [] - internal_arrays: List[ArrayOrNames] = [] - array_to_id: Dict[ArrayOrNames, str] = {} - - id_gen = UniqueNameGenerator() - for array in nodes: - array_to_id[array] = id_gen("array") - if isinstance(array, InputArgumentBase): - input_arrays.append(array) - else: - internal_arrays.append(array) - - # Since 'asciidag' prints the DAG from top to bottom (ie, with the inputs - # at the bottom), we need to invert our representation of it, that is, the - # 'parents' constructor argument to Node() actually means 'children'. - from asciidag.node import Node # type: ignore[import] - asciidag_nodes: Dict[ArrayOrNames, Node] = {} - - from collections import defaultdict - asciidag_edges: Dict[ArrayOrNames, List[ArrayOrNames]] = defaultdict(list) - - # Reverse edge directions - for array in internal_arrays: - for _, v in nodes[array].edges.items(): - asciidag_edges[v].append(array) - - # Add the internal arrays in reversed order - for array in internal_arrays[::-1]: - ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] - - if array == internal_arrays[-1]: - ary_edges.append(Node("Outputs")) - - asciidag_nodes[array] = Node(f"{nodes[array].title}", - parents=ary_edges) - - # Add the input arrays last since they have no predecessors - for array in input_arrays: - ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] - asciidag_nodes[array] = Node(f"{nodes[array].title}", parents=ary_edges) - - input_node = Node("Inputs", parents=[asciidag_nodes[v] for v in input_arrays]) - - from asciidag.graph import Graph # type: ignore[import] - from io import StringIO - - f = StringIO() - graph = Graph(fh=f, use_color=use_color) - - graph.show_nodes([input_node]) - - # Get the graph and remove trailing whitespace - res = "\n".join([s.rstrip() for s in f.getvalue().split("\n")]) - - return res - - -def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None: - """Print a graph representing the computation of *result* to stdout using the - `asciidag `_ package. - - :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. - """ - - print(get_ascii_graph(result, use_color=True)) # }}} -- GitLab