diff --git a/.gitignore b/.gitignore index 1b2f54db1ea48ac89fb5db925859ebf82f9f895b..8e4ee63f2a6f59f0ddccce9f8aa05f1ff578f77f 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,6 @@ distribute*tar.gz doc/_build .mypy_cache + +*.dot +*.svg diff --git a/doc/reference.rst b/doc/reference.rst index 43f655519a9e240bafef04372771f415ffed42a3..581f7b65b204ba8f2a1736ddc8d75d82e484ee34 100644 --- a/doc/reference.rst +++ b/doc/reference.rst @@ -7,3 +7,4 @@ Reference .. automodule:: pytato.program .. automodule:: pytato.target .. automodule:: pytato.codegen +.. automodule:: pytato.visualization diff --git a/examples/visualization.py b/examples/visualization.py new file mode 100755 index 0000000000000000000000000000000000000000..f41f33b93c38dc0fdfbaae47a3e1338f30502177 --- /dev/null +++ b/examples/visualization.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +"""Demonstrates graph visualization with Graphviz.""" + +import logging +import numpy as np +import shutil +import subprocess + +import pytato as pt + + +logger = logging.getLogger(__name__) + + +GRAPH_DOT = "graph.dot" +GRAPH_SVG = "graph.svg" + + +def main(): + ns = pt.Namespace() + + pt.make_size_param(ns, "n") + array = pt.make_placeholder(ns, "array", shape="n", dtype=np.float) + stack = pt.stack([array, 2*array, array + 6]) + ns.assign("stack", stack) + result = stack @ stack.T + + dot_code = pt.get_dot_graph(result) + + with open(GRAPH_DOT, "w") as outf: + outf.write(dot_code) + logger.info("wrote '%s'", GRAPH_DOT) + + dot_path = shutil.which("dot") + if dot_path is not None: + subprocess.run([dot_path, "-Tsvg", GRAPH_DOT, "-o", GRAPH_SVG], check=True) + logger.info("wrote '%s'", GRAPH_SVG) + else: + logger.info("'dot' executable not found; cannot convert to SVG") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/pytato/__init__.py b/pytato/__init__.py index 8eab25081837165f31abeeeefe8bdf3008fdde8e..d65fefcc8dbb71f20edbae81a4ed6cca0940b0f2 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -36,6 +36,7 @@ from pytato.array import ( from pytato.codegen import generate_loopy from pytato.target import Target, PyOpenCLTarget +from pytato.visualization import get_dot_graph, show_dot_graph __all__ = ( "DottedName", "Namespace", "Array", "DictOfNamedArrays", @@ -49,4 +50,6 @@ __all__ = ( "generate_loopy", "Target", "PyOpenCLTarget", + + "get_dot_graph", "show_dot_graph", ) diff --git a/pytato/codegen.py b/pytato/codegen.py index 5a62bb278dd2bee49640990d7f6630ef065372e0..70a2694532772e00a27f27c451880717a45f400b 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -76,6 +76,7 @@ Code Generation Internals .. autofunction:: get_loopy_temporary .. autofunction:: add_store .. autofunction:: rename_reductions +.. autofunction:: normalize_outputs .. autofunction:: get_initial_codegen_state .. autofunction:: preprocess @@ -733,6 +734,26 @@ def rename_reductions( return result +def normalize_outputs(result: Union[Array, DictOfNamedArrays]) -> DictOfNamedArrays: + """Convert outputs of a computation to the canonical form. + + Performs a conversion to :class:`~pytato.DictOfNamedArrays` if necessary. + + :param result: Outputs of the computation. + """ + if not isinstance(result, (Array, DictOfNamedArrays)): + raise TypeError("outputs of the computation should be " + "either an Array or a DictOfNamedArrays") + + if isinstance(result, Array): + outputs = DictOfNamedArrays({"_pt_out": result}) + else: + assert isinstance(result, DictOfNamedArrays) + outputs = result + + return outputs + + def get_initial_codegen_state(namespace: Namespace, target: Target, options: Optional[lp.Options]) -> CodeGenState: kernel = lp.make_kernel("{:}", [], @@ -774,20 +795,9 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays], :param options: Code generation options for the kernel. :returns: A wrapped generated :mod:`loopy` kernel """ - # {{{ get namespace and outputs - - orig_outputs: DictOfNamedArrays - - if isinstance(result, Array): - orig_outputs = DictOfNamedArrays({"_pt_out": result}) - else: - assert isinstance(result, DictOfNamedArrays) - orig_outputs = result - + orig_outputs: DictOfNamedArrays = normalize_outputs(result) del result - # }}} - if target is None: target = PyOpenCLTarget() diff --git a/pytato/visualization.py b/pytato/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..421ad2af746a510c694f47328f2e163772d64c68 --- /dev/null +++ b/pytato/visualization.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +__copyright__ = """ +Copyright (C) 2020 Matt Wala +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import contextlib +import dataclasses +import html +import itertools +from typing import Callable, Dict, Union, Iterator, List, Mapping + +from pytools import UniqueNameGenerator +from pytools.codegen import CodeGenerator as CodeGeneratorBase + +from pytato.array import ( + Array, DictOfNamedArrays, IndexLambda, InputArgumentBase, + Stack, ShapeType, TagsType) +from pytato.codegen import normalize_outputs +import pytato.transform + + +__doc__ = """ +.. currentmodule:: pytato + +Graph Visualization +------------------- + +.. autofunction:: get_dot_graph +.. autofunction:: show_dot_graph + +""" + + +# {{{ array -> dot node converter + +@dataclasses.dataclass +class DotNodeInfo: + title: str + fields: Dict[str, str] + edges: Dict[str, Array] + + +def stringify_tags(tags: TagsType) -> str: + components = sorted(str(elem) for elem in tags) + return "{" + ", ".join(components) + "}" + + +def stringify_shape(shape: ShapeType) -> str: + components = [str(elem) for elem in shape] + if not components: + components = [","] + elif len(components) == 1: + components[0] += "," + return "(" + ", ".join(components) + ")" + + +class ArrayToDotNodeInfoMapper(pytato.transform.Mapper): + def get_common_dot_info(self, expr: Array) -> DotNodeInfo: + title = type(expr).__name__ + fields = dict(shape=stringify_shape(expr.shape), + dtype=str(expr.dtype), + tags=stringify_tags(expr.tags)) + edges: Dict[str, Array] = {} + return DotNodeInfo(title, fields, edges) + + def handle_unsupported_array(self, expr: Array, # type: ignore + nodes: Dict[Array, DotNodeInfo]) -> None: + # Default handler, does its best to guess how to handle fields. + if expr in nodes: + return + info = self.get_common_dot_info(expr) + + for field in expr.fields: + if field in info.fields: + continue + attr = getattr(expr, field) + + if isinstance(attr, Array): + self.rec(attr, nodes) + info.edges[field] = attr + elif isinstance(attr, tuple): + info.fields[field] = stringify_shape(attr) + else: + info.fields[field] = str(attr) + + nodes[expr] = info + + def map_index_lambda(self, expr: IndexLambda, + nodes: Dict[Array, DotNodeInfo]) -> None: + if expr in nodes: + return + + info = self.get_common_dot_info(expr) + info.fields["expr"] = str(expr.expr) + + for name, val in expr.bindings.items(): + self.rec(val, nodes) + info.edges[name] = val + + nodes[expr] = info + + def map_stack(self, expr: Stack, nodes: Dict[Array, DotNodeInfo]) -> None: + if expr in nodes: + return + + info = self.get_common_dot_info(expr) + info.fields["axis"] = str(expr.axis) + + for i, array in enumerate(expr.arrays): + self.rec(array, nodes) + info.edges[str(i)] = array + + nodes[expr] = info + + +def dot_escape(s: str) -> str: + # "\" and HTML are significant in graphviz. + return html.escape(s.replace("\\", "\\\\")) + + +class DotEmitter(CodeGeneratorBase): + @contextlib.contextmanager + def block(self, name: str) -> Iterator[None]: + self(name + " {") + self.indent() + yield + self.dedent() + self("}") + + +def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str) -> None: + td_attrib = "border=\"0\"" + table_attrib = "border=\"0\" cellborder=\"1\" cellspacing=\"0\"" + + rows = ["