From 30696fcdcb1fce1e538531829e01ab750193015c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 20 Oct 2021 15:30:40 -0500 Subject: [PATCH] add asciidag visualization (#167) * initial asciidag implementation * cleanups * pip * fix lint errors * more comments * clarify doc * extract get_ascii_graph * add test * remove trailing ws * use_color argument --- .test-conda-env-py3.yml | 2 + examples/visualization.py | 2 + pytato/__init__.py | 5 ++- pytato/visualization.py | 93 ++++++++++++++++++++++++++++++++++++++- test/test_pytato.py | 29 ++++++++++++ 5 files changed, 127 insertions(+), 4 deletions(-) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 7c34e94..ff6eeb4 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -11,3 +11,5 @@ dependencies: - pyopencl - islpy - sphinx-autodoc-typehints +- pip: + - asciidag diff --git a/examples/visualization.py b/examples/visualization.py index ac71e60..2b569a3 100755 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -22,6 +22,8 @@ def main(): stack = pt.stack([array, 2*array, array + 6]) result = stack @ stack.T + pt.show_ascii_graph(result) + dot_code = pt.get_dot_graph(result) with open(GRAPH_DOT, "w") as outf: diff --git a/pytato/__init__.py b/pytato/__init__.py index 77e8685..a735334 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -59,7 +59,8 @@ from pytato.loopy import LoopyCall from pytato.target.loopy.codegen import generate_loopy from pytato.target import Target from pytato.target.loopy import LoopyPyOpenCLTarget -from pytato.visualization import get_dot_graph, show_dot_graph +from pytato.visualization import (get_dot_graph, show_dot_graph, + get_ascii_graph, show_ascii_graph) __all__ = ( "Array", "AbstractResultWithNamedArrays", "DictOfNamedArrays", @@ -75,7 +76,7 @@ __all__ = ( "Target", "LoopyPyOpenCLTarget", - "get_dot_graph", "show_dot_graph", + "get_dot_graph", "show_dot_graph", "get_ascii_graph", "show_ascii_graph", "abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", diff --git a/pytato/visualization.py b/pytato/visualization.py index 976cceb..5708400 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -46,6 +46,8 @@ __doc__ = """ .. autofunction:: get_dot_graph .. autofunction:: show_dot_graph +.. autofunction:: get_ascii_graph +.. autofunction:: show_ascii_graph """ @@ -257,11 +259,11 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: return emit.get() -def show_dot_graph(result: Union[Array, DictOfNamedArrays]) -> None: +def show_dot_graph(result: Union[str, Array, DictOfNamedArrays]) -> None: """Show a graph representing the computation of *result* in a browser. :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`). + :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. """ dot_code: str @@ -272,3 +274,90 @@ def show_dot_graph(result: Union[Array, DictOfNamedArrays]) -> None: from pymbolic.imperative.utils import show_dot show_dot(dot_code) + + +# {{{ Show ASCII representation of DAG + +def get_ascii_graph(result: Union[Array, DictOfNamedArrays], + use_color: bool = True) -> str: + """Return a string representing the computation of *result* + using the `asciidag `_ package. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`). + :arg use_color: Colorized output + """ + outputs: DictOfNamedArrays = normalize_outputs(result) + del result + + nodes: Dict[Array, DotNodeInfo] = {} + mapper = ArrayToDotNodeInfoMapper() + for elem in outputs._data.values(): + mapper(elem, nodes) + + input_arrays: List[Array] = [] + internal_arrays: List[Array] = [] + array_to_id: Dict[Array, str] = {} + + id_gen = UniqueNameGenerator() + for array in nodes: + array_to_id[array] = id_gen("array") + if isinstance(array, InputArgumentBase): + input_arrays.append(array) + else: + internal_arrays.append(array) + + # Since 'asciidag' prints the DAG from top to bottom (ie, with the inputs + # at the bottom), we need to invert our representation of it, that is, the + # 'parents' constructor argument to Node() actually means 'children'. + from asciidag.node import Node # type: ignore[import] + asciidag_nodes: Dict[Array, Node] = {} + + from collections import defaultdict + asciidag_edges: Dict[Array, List[Array]] = defaultdict(list) + + # Reverse edge directions + for array in internal_arrays: + for _, v in nodes[array].edges.items(): + asciidag_edges[v].append(array) + + # Add the internal arrays in reversed order + for array in internal_arrays[::-1]: + ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] + + if array == internal_arrays[-1]: + ary_edges.append(Node("Outputs")) + + asciidag_nodes[array] = Node(f"{nodes[array].title}", + parents=ary_edges) + + # Add the input arrays last since they have no predecessors + for array in input_arrays: + ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] + asciidag_nodes[array] = Node(f"{nodes[array].title}", parents=ary_edges) + + input_node = Node("Inputs", parents=[asciidag_nodes[v] for v in input_arrays]) + + from asciidag.graph import Graph # type: ignore[import] + from io import StringIO + + f = StringIO() + graph = Graph(fh=f, use_color=use_color) + + graph.show_nodes([input_node]) + + # Get the graph and remove trailing whitespace + res = "\n".join([s.rstrip() for s in f.getvalue().split("\n")]) + + return res + + +def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None: + """Show a graph representing the computation of *result* in a browser. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. + """ + + print(get_ascii_graph(result, use_color=True)) +# }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index fcbbb10..35dec3a 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -304,6 +304,35 @@ def test_toposortmapper(): assert isinstance(tm.topological_order[6], MatrixProduct) +def test_asciidag(): + n = pt.make_size_param("n") + array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) + stack = pt.stack([array, 2*array, array + 6]) + y = stack @ stack.T + + from pytato import get_ascii_graph + + res = get_ascii_graph(y, use_color=False) + + ref_str = r"""* Inputs +*-. Placeholder +|\ \ +* | | IndexLambda +| |/ +|/| +| * IndexLambda +|/ +* Stack +|\ +* | AxisPermutation +|/ +* MatrixProduct +* Outputs +""" + + assert res == ref_str + + def test_linear_complexity_inequality(): # See https://github.com/inducer/pytato/issues/163 import pytato as pt -- GitLab