diff --git a/pytato/visualization.py b/pytato/visualization.py index 774742ce64173a2a449586b3a0429ff633e01410..976cceb51a9d284b044a1adbb1f413e1ddb270ed 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -54,6 +54,7 @@ __doc__ = """ @dataclasses.dataclass class DotNodeInfo: title: str + addr: str fields: Dict[str, str] edges: Dict[str, Array] @@ -75,11 +76,12 @@ def stringify_shape(shape: ShapeType) -> str: class ArrayToDotNodeInfoMapper(pytato.transform.Mapper): def get_common_dot_info(self, expr: Array) -> DotNodeInfo: title = type(expr).__name__ + addr = hex(id(expr)) fields = dict(shape=stringify_shape(expr.shape), dtype=str(expr.dtype), tags=stringify_tags(expr.tags)) edges: Dict[str, Array] = {} - return DotNodeInfo(title, fields, edges) + return DotNodeInfo(title, addr, fields, edges) def handle_unsupported_array(self, expr: Array, # type: ignore nodes: Dict[Array, DotNodeInfo]) -> None: @@ -166,6 +168,9 @@ def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str) -> None: rows = ['<tr><td colspan="2" %s>%s</td></tr>' % (td_attrib, dot_escape(info.title))] + rows.append("<tr><td %s>%s:</td><td %s>%s</td></tr>" + % (td_attrib, "addr", td_attrib, info.addr)) + for name, field in info.fields.items(): rows.append( "<tr><td %s>%s:</td><td %s>%s</td></tr>"