From b953db0398a27d539bd36ca17085c050d45635f3 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <kaushikcfd@gmail.com> Date: Tue, 23 May 2023 12:56:29 -0500 Subject: [PATCH] Implement fancy placeholder focused data-flow viz --- pytato/__init__.py | 5 +- pytato/visualization/__init__.py | 4 + .../fancy_placeholder_data_flow.py | 300 ++++++++++++++++++ 3 files changed, 308 insertions(+), 1 deletion(-) create mode 100644 pytato/visualization/fancy_placeholder_data_flow.py diff --git a/pytato/__init__.py b/pytato/__init__.py index 731c9b0..0117d6b 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -90,7 +90,9 @@ from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.target.python.jax import generate_jax from pytato.visualization import (get_dot_graph, show_dot_graph, get_ascii_graph, show_ascii_graph, - get_dot_graph_from_partition) + get_dot_graph_from_partition, + show_fancy_placeholder_data_flow, + ) import pytato.analysis as analysis import pytato.tags as tags import pytato.transform as transform @@ -133,6 +135,7 @@ __all__ = ( "get_dot_graph", "show_dot_graph", "get_ascii_graph", "show_ascii_graph", "get_dot_graph_from_partition", + "show_fancy_placeholder_data_flow", "abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", "sqrt", "conj", "arctan2", diff --git a/pytato/visualization/__init__.py b/pytato/visualization/__init__.py index af9220a..6fed19f 100644 --- a/pytato/visualization/__init__.py +++ b/pytato/visualization/__init__.py @@ -3,14 +3,18 @@ .. automodule:: pytato.visualization.dot .. automodule:: pytato.visualization.ascii +.. automodule:: pytato.visualization.fancy_placeholder_data_flow """ from .dot import get_dot_graph, show_dot_graph, get_dot_graph_from_partition from .ascii import get_ascii_graph, show_ascii_graph +from .fancy_placeholder_data_flow import show_fancy_placeholder_data_flow __all__ = [ "get_dot_graph", "show_dot_graph", "get_dot_graph_from_partition", "get_ascii_graph", "show_ascii_graph", + + "show_fancy_placeholder_data_flow", ] diff --git a/pytato/visualization/fancy_placeholder_data_flow.py b/pytato/visualization/fancy_placeholder_data_flow.py new file mode 100644 index 0000000..e3e20fe --- /dev/null +++ b/pytato/visualization/fancy_placeholder_data_flow.py @@ -0,0 +1,300 @@ +""" +.. currentmodule:: pytato + +.. autofunction:: show_fancy_placeholder_data_flow +""" +from pytato.transform import CachedMapper +from dataclasses import dataclass +from pytools import UniqueNameGenerator +from typing import FrozenSet, Set, List, Tuple, Collection, Union, Any +from pytato.array import (Array, DictOfNamedArrays, Einsum, Stack, + Concatenate, IndexLambda, Placeholder, DataWrapper, + AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, + IndexRemappingBase) + + +# {{{ Graph node colors + +PLACEHOLDER_COLOR = "lightgrey" +ELEMWISE_COLOR = "coral1" +OUTPUT_COLOR = "springgreen" +EINSUM_COLOR = "crimson" +STACK_CONCAT_COLOR = "deepskyblue" +INDIRECTION_COLOR = "darkblue" + +# }}} + +# {{{ Graph node shapes + +PLACEHOLDER_SHAPE = "ellipse" +OUTPUT_SHAPE = "ellipse" +ELEMWISE_SHAPE = "diamond" +EINSUM_SHAPE = "box3d" +STACK_CONCAT_SHAPE = "folder" +INDIRECTION_SHAPE = "hexagon" + +# }}} + + +@dataclass(frozen=True) +class _FancyDotWriterNode: + """ + Return type for :class:`FancyDotWriter`. + """ + pass + + +@dataclass(frozen=True) +class PlainOldDotNode(_FancyDotWriterNode): + """ + Node that would appear in the graphviz graph. + + .. attribute:: node_id + + ID of the node in the graphviz-based graph. + """ + node_id: str + + +@dataclass(frozen=True) +class NoShowNode(_FancyDotWriterNode): + """ + Node that will not appear in the graphviz graph being built by a + :class:`FancyDotWriter`. + """ + + +def _get_dot_node_from_predecessors(node_id: str, + predecessors: Collection[_FancyDotWriterNode], + ) -> Tuple[_FancyDotWriterNode, + FrozenSet[Tuple[str, str]]]: + + new_edges: Set[Tuple[str, str]] = set() + + for pred in predecessors: + if isinstance(pred, PlainOldDotNode): + new_edges.add((pred.node_id, node_id)) + else: + assert isinstance(pred, NoShowNode) + + if new_edges: + return PlainOldDotNode(node_id), frozenset(new_edges) + else: + return NoShowNode(), frozenset() + + +class FancyDotWriter(CachedMapper[_FancyDotWriterNode]): + def __init__(self) -> None: + super().__init__() + self.vng = UniqueNameGenerator() + + self.node_decls: List[str] = [] + self.edges: Set[Tuple[str, str]] = set() + + def map_placeholder(self, expr: Placeholder) -> _FancyDotWriterNode: + node_decl = (f"{expr.name} [color={PLACEHOLDER_COLOR}, " + f"shape={PLACEHOLDER_SHAPE}]") + self.node_decls.append(node_decl) + return PlainOldDotNode(expr.name) + + def map_data_wrapper(self, expr: DataWrapper) -> _FancyDotWriterNode: + return NoShowNode() + + def map_index_lambda(self, expr: IndexLambda) -> _FancyDotWriterNode: + from pytato.raising import (index_lambda_to_high_level_op, + FullOp, BinaryOp, C99CallOp, + WhereOp, BroadcastOp, LogicalNotOp) + + hlo = index_lambda_to_high_level_op(expr) + + if isinstance(hlo, FullOp): + return NoShowNode() + elif isinstance(hlo, (BinaryOp, C99CallOp, WhereOp, + BroadcastOp, LogicalNotOp)): + node_id = self.vng("_pt_elem") + + node_decl = (f"{node_id}" + f' [label="",color={ELEMWISE_COLOR},' + f" shape={ELEMWISE_SHAPE}]") + else: + raise NotImplementedError(type(hlo)) + + ret_node, new_edges = _get_dot_node_from_predecessors( + node_id, + [self.rec(bnd) for bnd in expr.bindings.values()] + ) + + if new_edges: + self.node_decls.append(node_decl) + self.edges.update(new_edges) + + return ret_node + + def map_einsum(self, expr: Einsum) -> _FancyDotWriterNode: + from pytato.utils import get_einsum_subscript_str + + ensm_spec = get_einsum_subscript_str(expr) + node_id = self.vng("_pt_ensm") + spec = ensm_spec.replace("->", "→") + node_decl = (f'{node_id} [label="{spec}",' + f" color={EINSUM_COLOR}," + f" shape={EINSUM_SHAPE}," + " style=unfilled]") + + ret_node, new_edges = _get_dot_node_from_predecessors( + node_id, + [self.rec(arg) for arg in expr.args] + ) + + if new_edges: + self.node_decls.append(node_decl) + self.edges.update(new_edges) + + return ret_node + + def _map_stack_concat(self, + expr: Union[Stack, Concatenate]) -> _FancyDotWriterNode: + node_id = self.vng("_pt_stack_concat") + node_decl = (f'{node_id} [label="",' + f" color={STACK_CONCAT_COLOR}," + f" shape={STACK_CONCAT_SHAPE}]") + + ret_node, new_edges = _get_dot_node_from_predecessors( + node_id, + [self.rec(ary) for ary in expr.arrays] + ) + + if new_edges: + self.node_decls.append(node_decl) + self.edges.update(new_edges) + + return ret_node + + map_stack = _map_stack_concat + map_concatentate = _map_stack_concat + + def _map_index_remapping(self, + expr: IndexRemappingBase) -> _FancyDotWriterNode: + node_id = self.vng("_pt_idx_remap") + + node_decl = (f"{node_id}" + f' [label="",color={ELEMWISE_COLOR},' + f" shape={ELEMWISE_SHAPE}]") + ret_node, new_edges = _get_dot_node_from_predecessors( + node_id, + [self.rec(expr.array)] + ) + + if new_edges: + self.node_decls.append(node_decl) + self.edges.update(new_edges) + + return ret_node + + map_reshape = _map_index_remapping + map_roll = _map_index_remapping + map_axis_permutation = _map_index_remapping + map_basic_index = _map_index_remapping + + def _map_advanced_index(self, + expr: Union[AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes] + ) -> _FancyDotWriterNode: + node_id = self.vng("_pt_adv") + node_decl = (f"{node_id}" + f' [label="",color={INDIRECTION_COLOR},' + f" shape={INDIRECTION_SHAPE}]") + + ret_node, new_edges = _get_dot_node_from_predecessors( + node_id, + [self.rec(expr.array), + *[self.rec(idx) for idx in expr.indices if isinstance(idx, Array)]] + ) + + if new_edges: + self.node_decls.append(node_decl) + self.edges.update(new_edges) + + return ret_node + + map_contiguous_advanced_index = _map_advanced_index + map_non_contiguous_advanced_index = _map_advanced_index + + def map_dict_of_named_arrays(self, + expr: DictOfNamedArrays) -> _FancyDotWriterNode: + + for name, subexpr in expr._data.items(): + node_id = self.vng("_pt_out") + node_decl = (f"{node_id}" + f' [label="{name}",color={OUTPUT_COLOR},' + f" shape={OUTPUT_SHAPE}]") + rec_subexpr = self.rec(subexpr) + if isinstance(rec_subexpr, PlainOldDotNode): + self.node_decls.append(node_decl) + self.edges.add((rec_subexpr.node_id, node_id)) + else: + assert isinstance(rec_subexpr, NoShowNode) + + return NoShowNode() + + +def show_fancy_placeholder_data_flow(dag: Union[Array, DictOfNamedArrays], + **kwargs: Any) -> None: + """ + Visualizes the data-flow from the placeholders into outputs. + + :arg dag: The expression to be plotted. + :arg kwargs: Graphviz visualization options to be passed to + :func:`pytools.graphviz.show_dot`. + + .. note:: + + This is a heavily opinionated visualization of data-flow graph in + *dag*. Displaying all the information about the node is not the + priority. See :func:`pytato.show_dot_graph` that aims to be more + verbose. + """ + try: + from mako.template import Template + except ImportError: + raise RuntimeError("'show_fancy_placeholder_data_flow' requires" + " mako. Install as `pip install mako`.") + + if isinstance(dag, Array): + from pytato.array import make_dict_of_named_arrays + dag = make_dict_of_named_arrays({"_pt_out": dag}) + + assert isinstance(dag, DictOfNamedArrays) + + dot_writer = FancyDotWriter() + dot_writer(dag) + + dot_src = """ +digraph { + // set default properties + node[style=filled,fontsize=20] + edge[arrowhead=vee] + + // NODES + // ------------------------ + % for node_decl in node_decls: + ${node_decl} + % endfor + + // EDGES + // ------------------------ + % for edge in edges: + ${edge[0]} -> ${edge[1]} + % endfor +} + """ + + dot_code = (Template(dot_src, strict_undefined=True) + .render(node_decls=dot_writer.node_decls, + edges=dot_writer.edges)) + + from pytools.graphviz import show_dot + show_dot(dot_code, **kwargs) + +# vim:fdm=marker -- GitLab