diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 7c34e94ae8a63f1374259a9f84203e1b4edb4785..ff6eeb4a9c6c6f3278efe79d5b1f0bf9b4882e82 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -11,3 +11,5 @@ dependencies: - pyopencl - islpy - sphinx-autodoc-typehints +- pip: + - asciidag diff --git a/examples/visualization.py b/examples/visualization.py index ac71e6060b60adcf480e1864018ff1a57fa5d8b8..2b569a3925f464f097943c603eafc30ad863d0c8 100755 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -22,6 +22,8 @@ def main(): stack = pt.stack([array, 2*array, array + 6]) result = stack @ stack.T + pt.show_ascii_graph(result) + dot_code = pt.get_dot_graph(result) with open(GRAPH_DOT, "w") as outf: diff --git a/pytato/__init__.py b/pytato/__init__.py index 77e868567a209e9211531ac9c167316ff772109a..a735334e60efbbe2338eeca95cb74a1c641a88b1 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -59,7 +59,8 @@ from pytato.loopy import LoopyCall from pytato.target.loopy.codegen import generate_loopy from pytato.target import Target from pytato.target.loopy import LoopyPyOpenCLTarget -from pytato.visualization import get_dot_graph, show_dot_graph +from pytato.visualization import (get_dot_graph, show_dot_graph, + get_ascii_graph, show_ascii_graph) __all__ = ( "Array", "AbstractResultWithNamedArrays", "DictOfNamedArrays", @@ -75,7 +76,7 @@ __all__ = ( "Target", "LoopyPyOpenCLTarget", - "get_dot_graph", "show_dot_graph", + "get_dot_graph", "show_dot_graph", "get_ascii_graph", "show_ascii_graph", "abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", diff --git a/pytato/visualization.py b/pytato/visualization.py index 976cceb51a9d284b044a1adbb1f413e1ddb270ed..5708400f022ed40fbfbefda06ec8f8e79630610e 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -46,6 +46,8 @@ __doc__ = """ .. autofunction:: get_dot_graph .. autofunction:: show_dot_graph +.. autofunction:: get_ascii_graph +.. autofunction:: show_ascii_graph """ @@ -257,11 +259,11 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: return emit.get() -def show_dot_graph(result: Union[Array, DictOfNamedArrays]) -> None: +def show_dot_graph(result: Union[str, Array, DictOfNamedArrays]) -> None: """Show a graph representing the computation of *result* in a browser. :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`). + :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. """ dot_code: str @@ -272,3 +274,90 @@ def show_dot_graph(result: Union[Array, DictOfNamedArrays]) -> None: from pymbolic.imperative.utils import show_dot show_dot(dot_code) + + +# {{{ Show ASCII representation of DAG + +def get_ascii_graph(result: Union[Array, DictOfNamedArrays], + use_color: bool = True) -> str: + """Return a string representing the computation of *result* + using the `asciidag <https://pypi.org/project/asciidag/>`_ package. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`). + :arg use_color: Colorized output + """ + outputs: DictOfNamedArrays = normalize_outputs(result) + del result + + nodes: Dict[Array, DotNodeInfo] = {} + mapper = ArrayToDotNodeInfoMapper() + for elem in outputs._data.values(): + mapper(elem, nodes) + + input_arrays: List[Array] = [] + internal_arrays: List[Array] = [] + array_to_id: Dict[Array, str] = {} + + id_gen = UniqueNameGenerator() + for array in nodes: + array_to_id[array] = id_gen("array") + if isinstance(array, InputArgumentBase): + input_arrays.append(array) + else: + internal_arrays.append(array) + + # Since 'asciidag' prints the DAG from top to bottom (ie, with the inputs + # at the bottom), we need to invert our representation of it, that is, the + # 'parents' constructor argument to Node() actually means 'children'. + from asciidag.node import Node # type: ignore[import] + asciidag_nodes: Dict[Array, Node] = {} + + from collections import defaultdict + asciidag_edges: Dict[Array, List[Array]] = defaultdict(list) + + # Reverse edge directions + for array in internal_arrays: + for _, v in nodes[array].edges.items(): + asciidag_edges[v].append(array) + + # Add the internal arrays in reversed order + for array in internal_arrays[::-1]: + ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] + + if array == internal_arrays[-1]: + ary_edges.append(Node("Outputs")) + + asciidag_nodes[array] = Node(f"{nodes[array].title}", + parents=ary_edges) + + # Add the input arrays last since they have no predecessors + for array in input_arrays: + ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] + asciidag_nodes[array] = Node(f"{nodes[array].title}", parents=ary_edges) + + input_node = Node("Inputs", parents=[asciidag_nodes[v] for v in input_arrays]) + + from asciidag.graph import Graph # type: ignore[import] + from io import StringIO + + f = StringIO() + graph = Graph(fh=f, use_color=use_color) + + graph.show_nodes([input_node]) + + # Get the graph and remove trailing whitespace + res = "\n".join([s.rstrip() for s in f.getvalue().split("\n")]) + + return res + + +def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None: + """Show a graph representing the computation of *result* in a browser. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. + """ + + print(get_ascii_graph(result, use_color=True)) +# }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index fcbbb101881f143a6e174acbf25b1ba6965d7e78..35dec3accc3a0421981b3a6543ce99c65e3ef4a1 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -304,6 +304,35 @@ def test_toposortmapper(): assert isinstance(tm.topological_order[6], MatrixProduct) +def test_asciidag(): + n = pt.make_size_param("n") + array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) + stack = pt.stack([array, 2*array, array + 6]) + y = stack @ stack.T + + from pytato import get_ascii_graph + + res = get_ascii_graph(y, use_color=False) + + ref_str = r"""* Inputs +*-. Placeholder +|\ \ +* | | IndexLambda +| |/ +|/| +| * IndexLambda +|/ +* Stack +|\ +* | AxisPermutation +|/ +* MatrixProduct +* Outputs +""" + + assert res == ref_str + + def test_linear_complexity_inequality(): # See https://github.com/inducer/pytato/issues/163 import pytato as pt