From ea02d453ce4df68c6bebfe7e6d84f6b1f42bd025 Mon Sep 17 00:00:00 2001
From: Matt Wala <wala1@illinois.edu>
Date: Mon, 27 Jul 2020 21:42:10 -0500
Subject: [PATCH] Add graph visualization support in pytools.visualization

---
 .gitignore                |   3 +
 doc/reference.rst         |   1 +
 examples/visualization.py |  40 ++++++
 pytato/__init__.py        |   3 +
 pytato/codegen.py         |  37 ++++--
 pytato/visualization.py   | 261 ++++++++++++++++++++++++++++++++++++++
 requirements.txt          |   2 +-
 7 files changed, 332 insertions(+), 15 deletions(-)
 create mode 100755 examples/visualization.py
 create mode 100644 pytato/visualization.py

diff --git a/.gitignore b/.gitignore
index 1b2f54d..8e4ee63 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,6 @@ distribute*tar.gz
 doc/_build
 
 .mypy_cache
+
+*.dot
+*.svg
diff --git a/doc/reference.rst b/doc/reference.rst
index 43f6555..581f7b6 100644
--- a/doc/reference.rst
+++ b/doc/reference.rst
@@ -7,3 +7,4 @@ Reference
 .. automodule:: pytato.program
 .. automodule:: pytato.target
 .. automodule:: pytato.codegen
+.. automodule:: pytato.visualization
diff --git a/examples/visualization.py b/examples/visualization.py
new file mode 100755
index 0000000..d0150dd
--- /dev/null
+++ b/examples/visualization.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+"""Demonstrates graph visualization with Graphviz."""
+
+import logging
+import numpy as np
+import subprocess
+
+import pytato as pt
+
+
+logger = logging.getLogger(__name__)
+
+
+GRAPH_DOT = "graph.dot"
+GRAPH_SVG = "graph.svg"
+
+
+def main():
+    ns = pt.Namespace()
+
+    pt.make_size_param(ns, "n")
+    array = pt.make_placeholder(ns, "array", shape="n", dtype=np.float)
+    stack = pt.stack([array, 2*array, array + 6])
+    ns.assign("stack", stack)
+    result = stack @ stack.T
+
+    from pytato.visualization import get_dot_graph
+    dot_code = get_dot_graph(result)
+
+    with open(GRAPH_DOT, "w") as outf:
+        outf.write(dot_code)
+    logger.info("wrote '%s'", GRAPH_DOT)
+
+    subprocess.run(["dot", "-Tsvg", GRAPH_DOT, "-o", GRAPH_SVG], check=True)
+    logger.info("wrote '%s'", GRAPH_SVG)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    main()
diff --git a/pytato/__init__.py b/pytato/__init__.py
index 5bc498c..06a3d83 100644
--- a/pytato/__init__.py
+++ b/pytato/__init__.py
@@ -36,6 +36,7 @@ from pytato.array import (
 
 from pytato.codegen import generate_loopy
 from pytato.target import Target, PyOpenCLTarget
+from pytato.visualization import get_dot_graph, show_dot_graph
 
 __all__ = (
         "DottedName", "Namespace", "Array", "DictOfNamedArrays",
@@ -49,4 +50,6 @@ __all__ = (
         "generate_loopy",
 
         "Target", "PyOpenCLTarget",
+
+        "get_dot_graph", "show_dot_graph",
 )
diff --git a/pytato/codegen.py b/pytato/codegen.py
index f74051c..3a742b8 100644
--- a/pytato/codegen.py
+++ b/pytato/codegen.py
@@ -72,6 +72,7 @@ Code Generation Internals
 .. autofunction:: get_loopy_temporary
 .. autofunction:: add_store
 .. autofunction:: rename_reductions
+.. autofunction:: normalize_outputs
 
 """
 
@@ -689,6 +690,26 @@ def rename_reductions(
     loopy_expr_context.reduction_bounds = new_reduction_bounds
     return result
 
+
+def normalize_outputs(result: Union[Array, DictOfNamedArrays]) -> DictOfNamedArrays:
+    """Convert outputs of a computation to the canonical form.
+
+    Performs a conversion to :class:`~pytato.DictOfNamedArrays` if necessary.
+
+    :param result: Outputs of the computation.
+    """
+    if not isinstance(result, (Array, DictOfNamedArrays)):
+        raise TypeError("outputs of the computation should be "
+                "either an Array or a DictOfNamedArrays")
+
+    if isinstance(result, Array):
+        outputs = DictOfNamedArrays({"_pt_out": result})
+    else:
+        assert isinstance(outputs, DictOfNamedArrays)
+        outputs = result
+
+    return outputs
+
 # }}}
 
 
@@ -702,21 +723,9 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays],
     :param options: Code generation options for the kernel.
     :returns: A wrapped generated :mod:`loopy` kernel
     """
-    # {{{ get namespace and outputs
-
-    outputs: DictOfNamedArrays
-
-    if isinstance(result, Array):
-        outputs = DictOfNamedArrays({"_pt_out": result})
-        namespace = outputs.namespace
-    else:
-        assert isinstance(result, DictOfNamedArrays)
-        outputs = result
-
-    namespace = outputs.namespace
+    outputs: DictOfNamedArrays = normalize_outputs(result)
     del result
-
-    # }}}
+    namespace = outputs.namespace
 
     if target is None:
         target = PyOpenCLTarget()
diff --git a/pytato/visualization.py b/pytato/visualization.py
new file mode 100644
index 0000000..be264df
--- /dev/null
+++ b/pytato/visualization.py
@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+__copyright__ = """
+Copyright (C) 2020 Matt Wala
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import contextlib
+import dataclasses
+import html
+import itertools
+from typing import Callable, Dict, Union, Iterator, List, Mapping, Union
+
+from pytools import UniqueNameGenerator
+from pytools.codegen import CodeGenerator as CodeGeneratorBase
+
+from pytato.array import (
+        Array, DictOfNamedArrays, IndexLambda, InputArgumentBase,
+        Stack, ShapeType, TagsType)
+from pytato.codegen import normalize_outputs
+import pytato.transform
+
+
+__doc__ = """
+.. currentmodule:: pytato
+
+Graph Visualization
+-------------------
+
+.. autofunction:: get_dot_graph
+.. autofunction:: show_dot_graph
+
+"""
+
+
+# {{{ array -> dot node converter
+
+@dataclasses.dataclass
+class DotNodeInfo:
+    title: str
+    fields: Dict[str, str]
+    edges: Dict[str, Array]
+
+
+def stringify_tags(tags: TagsType) -> str:
+    components = sorted(str(elem) for elem in tags)
+    return "{" + ", ".join(components) + "}"
+
+
+def stringify_shape(shape: ShapeType) -> str:
+    components = [str(elem) for elem in shape]
+    if not components:
+        components = [","]
+    elif len(components) == 1:
+        components[0] += ","
+    return "(" + ", ".join(components) + ")"
+
+
+class ArrayToDotNodeInfoMapper(pytato.transform.Mapper):
+    def get_common_dot_info(self, expr: Array) -> DotNodeInfo:
+        title = type(expr).__name__
+        fields = dict(shape=stringify_shape(expr.shape),
+                dtype=str(expr.dtype),
+                tags=stringify_tags(expr.tags))
+        edges: Dict[str, Array] = {}
+        return DotNodeInfo(title, fields, edges)
+
+    def handle_unsupported_array(self, expr: Array,  # type: ignore
+            nodes: Dict[Array, DotNodeInfo]) -> None:
+        # Default handler, does its best to guess how to handle fields.
+        if expr in nodes:
+            return
+        info = self.get_common_dot_info(expr)
+
+        for field in expr.fields:
+            if field in info.fields:
+                continue
+            attr = getattr(expr, field)
+
+            if isinstance(attr, Array):
+                self.rec(attr, nodes)
+                info.edges[field] = attr
+            elif isinstance(attr, tuple):
+                info.fields[field] = stringify_shape(attr)
+            else:
+                info.fields[field] = str(attr)
+
+        nodes[expr] = info
+
+    def map_index_lambda(self, expr: IndexLambda,
+            nodes: Dict[Array, DotNodeInfo]) -> None:
+        if expr in nodes:
+            return
+
+        info = self.get_common_dot_info(expr)
+        info.fields["expr"] = str(expr.expr)
+
+        for name, val in expr.bindings.items():
+            self.rec(val, nodes)
+            info.edges[name] = val
+
+        nodes[expr] = info
+
+    def map_stack(self, expr: Stack, nodes: Dict[Array, DotNodeInfo]) -> None:
+        if expr in nodes:
+            return
+
+        info = self.get_common_dot_info(expr)
+        info.fields["axis"] = str(expr.axis)
+
+        for i, array in enumerate(expr.arrays):
+            self.rec(array, nodes)
+            info.edges[str(i)] = array
+
+        nodes[expr] = info
+
+
+def dot_escape(s: str) -> str:
+    # "\" and HTML are significant in graphviz.
+    return html.escape(s.replace("\\", "\\\\"))
+
+
+class DotEmitter(CodeGeneratorBase):
+    @contextlib.contextmanager
+    def block(self, name: str) -> Iterator[None]:
+        self(name + " {")
+        self.indent()
+        yield
+        self.dedent()
+        self("}")
+
+
+def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str) -> None:
+    td_attrib = "border=\"0\""
+    table_attrib = "border=\"0\" cellborder=\"1\" cellspacing=\"0\""
+
+    rows = ["<tr><td colspan=\"2\" %s>%s</td></tr>"
+            % (td_attrib, dot_escape(info.title))]
+
+    for name, field in info.fields.items():
+        rows.append(
+                "<tr><td %s>%s:</td><td %s>%s</td></tr>"
+                % (td_attrib, dot_escape(name), td_attrib, dot_escape(field)))
+
+    table = "<table %s>\n%s</table>" % (table_attrib, "".join(rows))
+    emit("%s [label=<%s>]" % (id, table))
+
+
+def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, Array],
+        array_to_id: Mapping[Array, str], id_gen: Callable[[str], str],
+        label: str) -> None:
+    edges = []
+
+    with emit.block("subgraph cluster_%s" % label):
+        emit("node [shape=ellipse]")
+        emit("label=\"%s\"" % label)
+
+        for name, array in names.items():
+            name_id = id_gen(label)
+            emit("%s [label=\"%s\"]" % (name_id, dot_escape(name)))
+            array_id = array_to_id[array]
+            # Edges must be outside the cluster.
+            edges.append((name_id, array_id))
+
+    for name_id, array_id in edges:
+        emit("%s -> %s" % (name_id, array_id))
+
+# }}}
+
+
+def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
+    r"""Return a string in the `dot <http://graphviz.org>`_ language depicting the
+    graph of the computation of *result*.
+
+    :arg result: Outputs of the computation (cf. :func:`pytato.generate_loopy`).
+    """
+    outputs: DictOfNamedArrays = normalize_outputs(result)
+    del result
+    namespace = outputs.namespace
+
+    nodes: Dict[Array, DotNodeInfo] = {}
+    mapper = ArrayToDotNodeInfoMapper()
+    for elem in itertools.chain(namespace.values(), outputs.values()):
+        mapper(elem, nodes)
+
+    input_arrays: List[Array] = []
+    internal_arrays: List[Array] = []
+    array_to_id: Dict[Array, str] = {}
+
+    id_gen = UniqueNameGenerator()
+    for array in nodes:
+        array_to_id[array] = id_gen("array")
+        if isinstance(array, InputArgumentBase):
+            input_arrays.append(array)
+        else:
+            internal_arrays.append(array)
+
+    emit = DotEmitter()
+
+    with emit.block("digraph computation"):
+        emit("node [shape=rectangle]")
+
+        # Emit inputs.
+        with emit.block("subgraph cluster_Inputs"):
+            emit("label=\"Inputs\"")
+            for array in input_arrays:
+                _emit_array(emit, nodes[array], array_to_id[array])
+
+        # Emit non-inputs.
+        for array in internal_arrays:
+            _emit_array(emit, nodes[array], array_to_id[array])
+
+        # Emit edges.
+        for array, node in nodes.items():
+            for label, tail_array in node.edges.items():
+                tail = array_to_id[tail_array]
+                head = array_to_id[array]
+                emit("%s -> %s [label=\"%s\"]" % (tail, head, dot_escape(label)))
+
+        # Emit output/namespace name mappings.
+        _emit_name_cluster(emit, outputs, array_to_id, id_gen, label="Outputs")
+        _emit_name_cluster(emit, namespace, array_to_id, id_gen, label="Namespace")
+
+    output: str = emit.get()
+    return output
+
+
+def show_dot_graph(result: Union[Array, DictOfNamedArrays]) -> None:
+    """Show a graph representing the computation of *result* in a browser.
+
+    :arg result: Outputs of the computation (cf. :func:`pytato.generate_loopy`).
+    """
+    dot_code: str
+
+    if isinstance(result, str):
+        dot_code = result
+    else:
+        dot_code = get_dot_graph(result)
+
+    from pymbolic.imperative.utils import show_dot
+    show_dot(dot_code)
diff --git a/requirements.txt b/requirements.txt
index 8900766..e489999 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-git+https://github.com/inducer/pytools.git
+git+https://github.com/mattwala/pytools.git@code-generator
 git+https://github.com/inducer/loopy.git
-- 
GitLab