From a00f8b4f1a1de19617c7cb3e2ba6000d30898377 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Jan 2023 10:46:41 -0600 Subject: [PATCH] make get_dot_graph_from_partition work with duplicated computation --- pytato/visualization.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 3eec7b8..4bdb103 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -46,7 +46,7 @@ from pytato.array import ( from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames -from pytato.partition import GraphPartition +from pytato.partition import GraphPartition, PartId from pytato.distributed.partition import DistributedGraphPart if TYPE_CHECKING: @@ -383,15 +383,29 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: with emit.block("digraph computation"): emit("node [shape=rectangle]") - array_to_id: Dict[ArrayOrNames, str] = {} + placeholder_to_id: Dict[ArrayOrNames, str] = {} + part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {} # First pass: generate names for all nodes for part in partition.parts.values(): + array_to_id = {} for array, _ in part_id_to_node_info[part.pid].items(): - array_to_id[array] = id_gen("array") + if isinstance(array, Placeholder): + # Placeholders are only emitted once + if array in placeholder_to_id: + node_id = placeholder_to_id[array] + else: + node_id = id_gen("array") + placeholder_to_id[array] = node_id + else: + node_id = id_gen("array") + array_to_id[array] = node_id + part_id_to_array_to_id[part.pid] = array_to_id # Second pass: emit the graph. for part in partition.parts.values(): + array_to_id = part_id_to_array_to_id[part.pid] + # {{{ emit receives nodes if distributed if isinstance(part, DistributedGraphPart): @@ -416,7 +430,7 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] - for array, _ in part_node_to_info.items(): + for array in part_node_to_info.keys(): if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -450,7 +464,13 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: pass else: # placeholder for a value from a different partition - tgt = array_to_id[ + computing_pid = None + for other_part in partition.parts.values(): + if array.name in other_part.output_names: + computing_pid = other_part.pid + break + assert computing_pid is not None + tgt = part_id_to_array_to_id[computing_pid][ partition.var_name_to_result[array.name]] emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") emitted_placeholders.add(array) -- GitLab