diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py
index b836e38147362ee4e03ed08dc0b64208cb7f46a3..0008cea20859c0c2accdd260e022863e2055eb59 100644
--- a/pytato/visualization/dot.py
+++ b/pytato/visualization/dot.py
@@ -26,18 +26,19 @@ THE SOFTWARE.
 """
 
-import contextlib
+from functools import partial
 import dataclasses
 import html
-from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List,
-                    Mapping, Hashable, Any, FrozenSet)
+from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List,
+                    Mapping, Any, FrozenSet, Set, Optional)
 
 from pytools import UniqueNameGenerator
 from pytools.tag import Tag
-from pytools.codegen import CodeGenerator as CodeGeneratorBase
 
 from pytato.loopy import LoopyCall
-from pytato.function import Call, NamedCallResult
+from pytato.function import Call, FunctionDefinition, NamedCallResult
+from pytato.tags import FunctionIdentifier
+from pytools.codegen import remove_common_indentation
 from pytato.array import (
         Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase,
@@ -45,7 +46,7 @@ from pytato.array import (
         IndexBase)
 
 from pytato.codegen import normalize_outputs
-from pytato.transform import CachedMapper, ArrayOrNames
+from pytato.transform import CachedMapper, ArrayOrNames, InputGatherer
 from pytato.distributed.partition import (
         DistributedGraphPartition, DistributedGraphPart, PartId)
 
@@ -63,13 +64,88 @@ __doc__ = """
 """
 
 
+# {{{ _DotEmitter
+
+@dataclasses.dataclass
+class _SubgraphTree:
+    contents: Optional[List[str]]
+    subgraphs: Dict[str, _SubgraphTree]
+
+
+class DotEmitter:
+    def __init__(self) -> None:
+        self.subgraph_to_lines: Dict[Tuple[str, ...], List[str]] = {}
+
+    def __call__(self, subgraph_path: Tuple[str, ...], s: str) -> None:
+        line_list = self.subgraph_to_lines.setdefault(subgraph_path, [])
+
+        if not s.strip():
+            line_list.append("")
+        else:
+            if "\n" in s:
+                s = remove_common_indentation(s)
+
+            for line in s.split("\n"):
+                line_list.append(line)
+
+    def _get_subgraph_tree(self) -> _SubgraphTree:
+        subgraph_tree = _SubgraphTree(contents=None, subgraphs={})
+
+        def insert_into_subgraph_tree(
+                root: _SubgraphTree, path: Tuple[str, ...], contents: List[str]
+                ) -> None:
+            if not path:
+                assert root.contents is None
+                root.contents = contents
+
+            else:
+                subgraph = root.subgraphs.setdefault(
+                        path[0],
+                        _SubgraphTree(contents=None, subgraphs={}))
+
+                insert_into_subgraph_tree(subgraph, path[1:], contents)
+
+        for sgp, lines in self.subgraph_to_lines.items():
+            insert_into_subgraph_tree(subgraph_tree, sgp, lines)
+
+        return subgraph_tree
+
+    def generate(self) -> str:
+        result = ["digraph computation {"]
+
+        indent_level = 1
+
+        def emit_subgraph(sg: _SubgraphTree) -> None:
+            nonlocal indent_level
+
+            indent = (indent_level*4)*" "
+            if sg.contents:
+                for ln in sg.contents:
+                    result.append(indent + ln)
+
+            indent_level += 1
+            for sg_name, sub_sg in sg.subgraphs.items():
+                result.append(f"{indent}subgraph {sg_name} {{")
+                emit_subgraph(sub_sg)
+                result.append(f"{indent}" "}")
+            indent_level -= 1
+
+        emit_subgraph(self._get_subgraph_tree())
+
+        result.append("}")
+
+        return "\n".join(result)
+
+# }}}
+
+
 # {{{ array -> dot node converter
 
 @dataclasses.dataclass
-class DotNodeInfo:
+class _DotNodeInfo:
     title: str
     fields: Dict[str, str]
-    edges: Dict[str, ArrayOrNames]
+    edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]]
 
 
 def stringify_tags(tags: FrozenSet[Tag]) -> str:
@@ -89,17 +165,18 @@ def stringify_shape(shape: ShapeType) -> str:
 class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
     def __init__(self) -> None:
         super().__init__()
-        self.nodes: Dict[ArrayOrNames, DotNodeInfo] = {}
+        self.node_to_dot: Dict[ArrayOrNames, _DotNodeInfo] = {}
+        self.functions: Set[FunctionDefinition] = set()
 
-    def get_common_dot_info(self, expr: Array) -> DotNodeInfo:
+    def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
         title = type(expr).__name__
         fields = {"addr": hex(id(expr)),
                   "shape": stringify_shape(expr.shape),
                   "dtype": str(expr.dtype),
                   "tags": stringify_tags(expr.tags)}
 
-        edges: Dict[str, ArrayOrNames] = {}
-        return DotNodeInfo(title, fields, edges)
+        edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
+        return _DotNodeInfo(title, fields, edges)
 
     # type-ignore-reason: incompatible with supertype
     def handle_unsupported_array(self,  # type: ignore[override]
@@ -126,7 +203,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             else:
                 info.fields[field] = str(attr)
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_data_wrapper(self, expr: DataWrapper) -> None:
         info = self.get_common_dot_info(expr)
@@ -138,7 +215,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
         with np.printoptions(threshold=4, precision=2):
             info.fields["data"] = str(expr.data)
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_index_lambda(self, expr: IndexLambda) -> None:
         info = self.get_common_dot_info(expr)
@@ -148,7 +225,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             self.rec(val)
             info.edges[name] = val
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_stack(self, expr: Stack) -> None:
         info = self.get_common_dot_info(expr)
@@ -158,7 +235,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             self.rec(array)
             info.edges[str(i)] = array
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     map_concatenate = map_stack
 
@@ -189,7 +266,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
         self.rec(expr.array)
         info.edges["array"] = expr.array
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     map_contiguous_advanced_index = map_basic_index
     map_non_contiguous_advanced_index = map_basic_index
@@ -202,27 +279,27 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             self.rec(val)
             info.edges[f"{iarg}: {access_descr}"] = val
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
-        edges: Dict[str, ArrayOrNames] = {}
+        edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
         for name, val in expr._data.items():
            edges[name] = val
           self.rec(val)
 
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=type(expr).__name__,
                 fields={},
                 edges=edges)
 
     def map_loopy_call(self, expr: LoopyCall) -> None:
-        edges: Dict[str, ArrayOrNames] = {}
+        edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
        for name, arg in expr.bindings.items():
            if isinstance(arg, Array):
                edges[name] = arg
                self.rec(arg)
 
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=type(expr).__name__,
                 fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint},
                 edges=edges)
@@ -242,15 +319,19 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
 
         info.fields["comm_tag"] = str(expr.send.comm_tag)
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_call(self, expr: Call) -> None:
+        self.functions.add(expr.function)
+
         for bnd in expr.bindings.values():
             self.rec(bnd)
 
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=expr.__class__.__name__,
-                edges=dict(expr.bindings),
+                edges={
+                    "": expr.function,
+                    **expr.bindings},
                 fields={
                     "addr": hex(id(expr)),
                     "tags": stringify_tags(expr.tags),
@@ -259,30 +340,24 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
 
     def map_named_call_result(self, expr: NamedCallResult) -> None:
         self.rec(expr._container)
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=expr.__class__.__name__,
                 edges={"": expr._container},
                 fields={"addr": hex(id(expr)), "name": expr.name},
                 )
 
+# }}}
+
 
 def dot_escape(s: str) -> str:
     # "\" and HTML are significant in graphviz.
-    return html.escape(s.replace("\\", "\\\\"))
-
+    return html.escape(s.replace("\\", "\\\\").replace(" ", "_"))
 
-class DotEmitter(CodeGeneratorBase):
-    @contextlib.contextmanager
-    def block(self, name: str) -> Iterator[None]:
-        self(name + " {")
-        self.indent()
-        yield
-        self.dedent()
-        self("}")
 
+# {{{ emit helpers
 
-def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
+def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str],
         dot_node_id: str, color: str = "white") -> None:
     td_attrib = 'border="0"'
     table_attrib = 'border="0" cellborder="1" cellspacing="0"'
@@ -300,84 +375,189 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
     emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color))
 
 
-def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames],
+def _emit_name_cluster(
+        emit: DotEmitter, subgraph_path: Tuple[str, ...],
+        names: Mapping[str, ArrayOrNames],
         array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str],
         label: str) -> None:
     edges = []
 
-    with emit.block("subgraph cluster_%s" % label):
-        emit("node [shape=ellipse]")
-        emit('label="%s"' % label)
+    cluster_subgraph_path = subgraph_path + (f"cluster_{dot_escape(label)}",)
+    emit_cluster = partial(emit, cluster_subgraph_path)
+    emit_cluster("node [shape=ellipse]")
+    emit_cluster(f'label="{label}"')
 
-        for name, array in names.items():
-            name_id = id_gen(label)
-            emit('%s [label="%s"]' % (name_id, dot_escape(name)))
-            array_id = array_to_id[array]
-            # Edges must be outside the cluster.
-            edges.append((name_id, array_id))
+    for name, array in names.items():
+        name_id = id_gen(dot_escape(name))
+        emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name)))
+        array_id = array_to_id[array]
+        # Edges must be outside the cluster.
+        edges.append((name_id, array_id))
 
     for name_id, array_id in edges:
-        emit("%s -> %s" % (name_id, array_id))
-
-# }}}
+        emit(subgraph_path, "%s -> %s" % (array_id, name_id))
 
 
-def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
-    r"""Return a string in the `dot <https://graphviz.org>`_ language depicting the
-    graph of the computation of *result*.
-
-    :arg result: Outputs of the computation (cf.
-        :func:`pytato.generate_loopy`).
-    """
-    outputs: DictOfNamedArrays = normalize_outputs(result)
-    del result
-
-    mapper = ArrayToDotNodeInfoMapper()
-    for elem in outputs._data.values():
-        mapper(elem)
-
-    nodes = mapper.nodes
-
+def _emit_function(
+        emitter: DotEmitter, subgraph_path: Tuple[str, ...],
+        id_gen: UniqueNameGenerator,
+        node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo],
+        func_to_id: Mapping[FunctionDefinition, str],
+        outputs: Mapping[str, Array]) -> None:
     input_arrays: List[Array] = []
     internal_arrays: List[ArrayOrNames] = []
     array_to_id: Dict[ArrayOrNames, str] = {}
 
-    id_gen = UniqueNameGenerator()
-    for array in nodes:
+    emit = partial(emitter, subgraph_path)
+    for array in node_to_dot:
         array_to_id[array] = id_gen("array")
         if isinstance(array, InputArgumentBase):
             input_arrays.append(array)
         else:
             internal_arrays.append(array)
 
-    emit = DotEmitter()
+    # Emit inputs.
+    input_subgraph_path = subgraph_path + ("cluster_inputs",)
+    emit_input = partial(emitter, input_subgraph_path)
+    emit_input('label="Arguments"')
+
+    for array in input_arrays:
+        _emit_array(
+                emit_input,
+                node_to_dot[array].title,
+                node_to_dot[array].fields,
+                array_to_id[array])
+
+    # Emit non-inputs.
+    for array in internal_arrays:
+        _emit_array(emit,
+                node_to_dot[array].title,
+                node_to_dot[array].fields,
+                array_to_id[array])
+
+    # Emit edges.
+    for array, node in node_to_dot.items():
+        for label, tail_item in node.edges.items():
+            head = array_to_id[array]
+            if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
+                tail = array_to_id[tail_item]
+            elif isinstance(tail_item, FunctionDefinition):
+                tail = func_to_id[tail_item]
+            else:
+                raise ValueError(
+                        f"unexpected type of tail on edge: {type(tail_item)}")
+
+            emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label)))
 
-    with emit.block("digraph computation"):
-        emit("node [shape=rectangle]")
+    # Emit output/namespace name mappings.
+    _emit_name_cluster(
+            emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns")
 
-        # Emit inputs.
-        with emit.block("subgraph cluster_Inputs"):
-            emit('label="Inputs"')
-            for array in input_arrays:
-                _emit_array(emit,
-                        nodes[array].title, nodes[array].fields, array_to_id[array])
+# }}}
 
-        # Emit non-inputs.
-        for array in internal_arrays:
-            _emit_array(emit,
-                    nodes[array].title, nodes[array].fields, array_to_id[array])
 
-        # Emit edges.
-        for array, node in nodes.items():
-            for label, tail_array in node.edges.items():
-                tail = array_to_id[tail_array]
-                head = array_to_id[array]
-                emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label)))
+# {{{ information gathering
 
-        # Emit output/namespace name mappings.
-        _emit_name_cluster(emit, outputs._data, array_to_id, id_gen, label="Outputs")
+def _get_function_name(f: FunctionDefinition) -> Optional[str]:
+    func_id_tags = f.tags_of_type(FunctionIdentifier)
+    if func_id_tags:
+        func_id_tag, = func_id_tags
+        return str(func_id_tag.identifier)
+    else:
+        return None
+
+
+def _gather_partition_node_information(
+        id_gen: UniqueNameGenerator,
+        partition: DistributedGraphPartition
+        ) -> Tuple[
+            Mapping[PartId, Mapping[FunctionDefinition, str]],
+            Mapping[Tuple[PartId, Optional[FunctionDefinition]],
+                Mapping[ArrayOrNames, _DotNodeInfo]]
+            ]:
+    part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {}
+    part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]],
+            Dict[ArrayOrNames, _DotNodeInfo]] = {}
+
+    for part in partition.parts.values():
+        mapper = ArrayToDotNodeInfoMapper()
+        for out_name in part.output_names:
+            mapper(partition.name_to_output[out_name])
+
+        part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot
+        part_id_to_func_to_id[part.pid] = {}
 
-    return emit.get()
+        # It is important that seen functions are emitted callee-first.
+        # (Otherwise function 'entry' nodes will get declared in the wrong
+        # cluster.) So use a data type that preserves order.
+        seen_functions: List[FunctionDefinition] = []
+
+        def gather_function_info(f: FunctionDefinition) -> None:
+            key = (part.pid, f)  # noqa: B023
+            if key in part_id_func_to_node_info:
+                return
+
+            mapper = ArrayToDotNodeInfoMapper()
+            for elem in f.returns.values():
+                mapper(elem)
+
+            part_id_func_to_node_info[key] = mapper.node_to_dot
+
+            for subfunc in mapper.functions:
+                gather_function_info(subfunc)
+
+            if f not in seen_functions:  # noqa: B023
+                seen_functions.append(f)  # noqa: B023
+
+        for f in mapper.functions:
+            gather_function_info(f)
+
+        # Again, important to preserve function order. Here we're relying
+        # on dicts to preserve order.
+        for f in seen_functions:
+            func_name = _get_function_name(f)
+            if func_name is not None:
+                fid = id_gen(dot_escape(func_name))
+            else:
+                fid = id_gen("func")
+
+            part_id_to_func_to_id.setdefault(part.pid, {})[f] = fid
+
+    return part_id_to_func_to_id, part_id_func_to_node_info
+
+# }}}
+
+
+def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
+    r"""Return a string in the `dot <https://graphviz.org>`_ language depicting the
+    graph of the computation of *result*.
+
+    :arg result: Outputs of the computation (cf.
+        :func:`pytato.generate_loopy`).
+    """
+
+    outputs: DictOfNamedArrays = normalize_outputs(result)
+
+    return get_dot_graph_from_partition(
+            DistributedGraphPartition(
+                parts={
+                    None: DistributedGraphPart(
+                        pid=None,
+                        needed_pids=frozenset(),
+                        user_input_names=frozenset(
                            expr.name
                            for expr in InputGatherer()(outputs)
                            if isinstance(expr, Placeholder)
                            ),
+                        partition_input_names=frozenset(),
+                        output_names=frozenset(outputs.keys()),
+                        name_to_recv_node={},
+                        name_to_send_nodes={},
+                        )
+                    },
+                name_to_output=outputs._data,
+                overall_output_names=tuple(outputs),
+                ))
 
 
 def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
@@ -386,170 +566,226 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
 
     :arg partition: Outputs of :func:`~pytato.find_distributed_partition`.
     """
-    # Maps each partition to a dict of its arrays with the node info
-    part_id_to_node_info: Dict[Hashable, Dict[ArrayOrNames, DotNodeInfo]] = {}
+    id_gen = UniqueNameGenerator()
 
-    for part in partition.parts.values():
-        mapper = ArrayToDotNodeInfoMapper()
-        for out_name in part.output_names:
-            mapper(partition.name_to_output[out_name])
+    # {{{ gather up node info, per partition and per function
 
-        part_id_to_node_info[part.pid] = mapper.nodes
+    # The "None" function is the body of the partition.
 
-    id_gen = UniqueNameGenerator()
+    part_id_to_func_to_id, part_id_func_to_node_info = \
+            _gather_partition_node_information(id_gen, partition)
 
-    emit = DotEmitter()
+    # }}}
+
+    emitter = DotEmitter()
+    emit_root = partial(emitter, ())
     emitted_placeholders = set()
 
-    with emit.block("digraph computation"):
-        emit("node [shape=rectangle]")
-        placeholder_to_id: Dict[ArrayOrNames, str] = {}
-        part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {}
-
-        # First pass: generate names for all nodes
-        for part in partition.parts.values():
-            array_to_id = {}
-            for array, _ in part_id_to_node_info[part.pid].items():
-                if isinstance(array, Placeholder):
-                    # Placeholders are only emitted once
-                    if array in placeholder_to_id:
-                        node_id = placeholder_to_id[array]
-                    else:
-                        node_id = id_gen("array")
-                        placeholder_to_id[array] = node_id
-                else:
-                    node_id = id_gen("array")
-                array_to_id[array] = node_id
-            part_id_to_array_to_id[part.pid] = array_to_id
-
-        # Second pass: emit the graph.
-        for part in partition.parts.values():
-            array_to_id = part_id_to_array_to_id[part.pid]
-
-            # {{{ emit receives nodes if distributed
-
-            if isinstance(part, DistributedGraphPart):
-                part_dist_recv_var_name_to_node_id = {}
-                for name, recv in (
-                        part.name_to_recv_node.items()):
-                    node_id = id_gen("recv")
-                    _emit_array(emit, "DistributedRecv", {
-                        "shape": stringify_shape(recv.shape),
-                        "dtype": str(recv.dtype),
-                        "src_rank": str(recv.src_rank),
-                        "comm_tag": str(recv.comm_tag),
-                        }, node_id)
-
-                    part_dist_recv_var_name_to_node_id[name] = node_id
-            else:
-                part_dist_recv_var_name_to_node_id = {}
-
-            # }}}
-
-            part_node_to_info = part_id_to_node_info[part.pid]
-            input_arrays: List[Array] = []
-            internal_arrays: List[ArrayOrNames] = []
-
-            for array in part_node_to_info.keys():
-                if isinstance(array, InputArgumentBase):
-                    input_arrays.append(array)
-                else:
-                    internal_arrays.append(array)
-
-            # {{{ emit inputs
-
-            # Placeholders are unique, i.e. the same Placeholder object may be
-            # shared among partitions. Therefore, they should not live inside
-            # the (dot) subgraph, otherwise they would be forced into multiple
-            # subgraphs.
-
-            for array in input_arrays:
-                # Non-Placeholders are emitted *inside* their subgraphs below.
-                if isinstance(array, Placeholder):
-                    if array not in emitted_placeholders:
-                        _emit_array(emit,
-                                part_node_to_info[array].title,
-                                part_node_to_info[array].fields,
-                                array_to_id[array], "deepskyblue")
-
-                        # Emit cross-partition edges
-                        if array.name in part_dist_recv_var_name_to_node_id:
-                            tgt = part_dist_recv_var_name_to_node_id[array.name]
-                            emit(f"{tgt} -> {array_to_id[array]} [style=dotted]")
-                            emitted_placeholders.add(array)
-                        elif array.name in part.user_input_names:
-                            # These are placeholders for external input. They
-                            # are cleanly associated with a single partition
-                            # and thus emitted below.
-                            pass
-                        else:
-                            # placeholder for a value from a different partition
-                            computing_pid = None
-                            for other_part in partition.parts.values():
-                                if array.name in other_part.output_names:
-                                    computing_pid = other_part.pid
-                                    break
-                            assert computing_pid is not None
-                            tgt = part_id_to_array_to_id[computing_pid][
-                                    partition.name_to_output[array.name]]
-                            emit(f"{tgt} -> {array_to_id[array]} [style=dashed]")
-                            emitted_placeholders.add(array)
-
-            # }}}
-
-            with emit.block(f'subgraph "cluster_part_{part.pid}"'):
-                emit("style=dashed")
-                emit(f'label="{part.pid}"')
-
-                for array in input_arrays:
-                    if (not isinstance(array, Placeholder)
-                            or array.name in part.user_input_names):
-                        _emit_array(emit,
-                                part_node_to_info[array].title,
-                                part_node_to_info[array].fields,
-                                array_to_id[array], "deepskyblue")
-
-                # Emit internal nodes
-                for array in internal_arrays:
-                    _emit_array(emit,
-                            part_node_to_info[array].title,
-                            part_node_to_info[array].fields,
-                            array_to_id[array])
-
-                # {{{ emit send nodes if distributed
-
-                deferred_send_edges = []
-                if isinstance(part, DistributedGraphPart):
-                    for name, sends in (
-                            part.name_to_send_nodes.items()):
-                        for send in sends:
-                            node_id = id_gen("send")
-                            _emit_array(emit, "DistributedSend", {
-                                "dest_rank": str(send.dest_rank),
-                                "comm_tag": str(send.comm_tag),
-                                }, node_id)
-
-                            deferred_send_edges.append(
-                                    f"{array_to_id[send.data]} -> {node_id}"
-                                    f'[style=dotted, label="{dot_escape(name)}"]')
-
-                # }}}
-
-                # If an edge is emitted in a subgraph, it drags its nodes into the
-                # subgraph, too. Not what we want.
-                for edge in deferred_send_edges:
-                    emit(edge)
-
-                # Emit intra-partition edges
-                for array, node in part_node_to_info.items():
-                    for label, tail_array in node.edges.items():
-                        tail = array_to_id[tail_array]
-                        head = array_to_id[array]
-                        emit('%s -> %s [label="%s"]' %
-                             (tail, head, dot_escape(label)))
-
-    return emit.get()
+    emit_root("node [shape=rectangle]")
+
+    placeholder_to_id: Dict[ArrayOrNames, str] = {}
+    part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {}
+
+    part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts}
+    assert len(set(part_id_to_id.values())) == len(partition.parts)
+
+    # {{{ generate names for all nodes in the root/None function
+
+    for part in partition.parts.values():
+        array_to_id = {}
+        for array in part_id_func_to_node_info[part.pid, None].keys():
+            if isinstance(array, Placeholder):
+                # Placeholders are only emitted once
+                if array in placeholder_to_id:
+                    node_id = placeholder_to_id[array]
+                else:
+                    node_id = id_gen("array")
+                    placeholder_to_id[array] = node_id
+            else:
+                node_id = id_gen("array")
+            array_to_id[array] = node_id
+        part_id_to_array_to_id[part.pid] = array_to_id
+
+    # }}}
+
+    # {{{ emit the graph
+
+    for part in partition.parts.values():
+        array_to_id = part_id_to_array_to_id[part.pid]
+
+        is_trivial_partition = part.pid is None and len(partition.parts) == 1
+        if is_trivial_partition:
+            part_subgraph_path: Tuple[str, ...] = ()
+        else:
+            part_subgraph_path = (f"cluster_{part_id_to_id[part.pid]}",)
+
+        emit_part = partial(emitter, part_subgraph_path)
+
+        if not is_trivial_partition:
+            emit_part("style=dashed")
+            emit_part(f'label="{part.pid}"')
+
+        # {{{ emit functions
+
+        # It is important that seen functions are emitted callee-first.
+        # Here we're relying on the part_id_to_func_to_id dict to preserve order.
+
+        for func, fid in part_id_to_func_to_id[part.pid].items():
+            func_subgraph_path = part_subgraph_path + (f"cluster_{fid}",)
+            label = _get_function_name(func) or fid
+
+            emitter(func_subgraph_path, f'label="{label}"')
+            emitter(func_subgraph_path, f'{fid} [label="{label}",shape="ellipse"]')
+
+            _emit_function(emitter, func_subgraph_path,
+                    id_gen, part_id_func_to_node_info[part.pid, func],
+                    part_id_to_func_to_id[part.pid],
+                    func.returns)
+
+        # }}}
+
+        # {{{ emit receives nodes
+
+        part_dist_recv_var_name_to_node_id = {}
+        for name, recv in (
+                part.name_to_recv_node.items()):
+            node_id = id_gen("recv")
+            _emit_array(emit_part, "DistributedRecv", {
+                "shape": stringify_shape(recv.shape),
+                "dtype": str(recv.dtype),
+                "src_rank": str(recv.src_rank),
+                "comm_tag": str(recv.comm_tag),
+                }, node_id)
+
+            part_dist_recv_var_name_to_node_id[name] = node_id
+
+        # }}}
+
+        part_node_to_info = part_id_func_to_node_info[part.pid, None]
+        input_arrays: List[Array] = []
+        internal_arrays: List[ArrayOrNames] = []
+
+        for array in part_node_to_info.keys():
+            if isinstance(array, InputArgumentBase):
+                input_arrays.append(array)
+            else:
+                internal_arrays.append(array)
+
+        # {{{ emit inputs
+
+        # Placeholders are unique, i.e. the same Placeholder object may be
+        # shared among partitions. Therefore, they should not live inside
+        # the (dot) subgraph, otherwise they would be forced into multiple
+        # subgraphs.
+
+        for array in input_arrays:
+            if not isinstance(array, Placeholder):
+                _emit_array(emit_part,
+                        part_node_to_info[array].title,
+                        part_node_to_info[array].fields,
+                        array_to_id[array], "deepskyblue")
+            else:
+                # Is a Placeholder
+                if array in emitted_placeholders:
+                    continue
+
+                _emit_array(emit_root,
+                        part_node_to_info[array].title,
+                        part_node_to_info[array].fields,
+                        array_to_id[array], "deepskyblue")
+
+                # Emit cross-partition edges
+                if array.name in part_dist_recv_var_name_to_node_id:
+                    tgt = part_dist_recv_var_name_to_node_id[array.name]
+                    emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]")
+                    emitted_placeholders.add(array)
+                elif array.name in part.user_input_names:
+                    # no arrows for these
+                    pass
+                else:
+                    # placeholder for a value from a different partition
+                    computing_pid = None
+                    for other_part in partition.parts.values():
+                        if array.name in other_part.output_names:
+                            computing_pid = other_part.pid
+                            break
+                    assert computing_pid is not None
+                    tgt = part_id_to_array_to_id[computing_pid][
+                            partition.name_to_output[array.name]]
+                    emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]")
+                    emitted_placeholders.add(array)
+
+        # }}}
+
+        # Emit internal nodes
+        for array in internal_arrays:
+            _emit_array(emit_part,
+                    part_node_to_info[array].title,
+                    part_node_to_info[array].fields,
+                    array_to_id[array])
+
+        # {{{ emit send nodes if distributed
+
+        if isinstance(part, DistributedGraphPart):
+            for name, sends in part.name_to_send_nodes.items():
+                for send in sends:
+                    node_id = id_gen("send")
+                    _emit_array(emit_part, "DistributedSend", {
+                        "dest_rank": str(send.dest_rank),
+                        "comm_tag": str(send.comm_tag),
+                        }, node_id)
+
+                    # If an edge is emitted in a subgraph, it drags its
+                    # nodes into the subgraph, too. Not what we want.
+                    emit_root(
+                            f"{array_to_id[send.data]} -> {node_id}"
+                            f'[style=dotted, label="{dot_escape(name)}"]')
+
+        # }}}
+
+        # Emit intra-partition edges
+        for array, node in part_node_to_info.items():
+            for label, tail_item in node.edges.items():
+                head = array_to_id[array]
+
+                if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
+                    tail = array_to_id[tail_item]
+                elif isinstance(tail_item, FunctionDefinition):
+                    tail = part_id_to_func_to_id[part.pid][tail_item]
+                else:
+                    raise ValueError(
+                            f"unexpected type of tail on edge: {type(tail_item)}")
+
+                emit_root('%s -> %s [label="%s"]' %
+                        (tail, head, dot_escape(label)))
+
+        _emit_name_cluster(
+                emitter, part_subgraph_path,
+                {name: partition.name_to_output[name] for name in part.output_names},
+                array_to_id, id_gen, "Part outputs")
+
+    # }}}
+
+    # Arrays may occur in multiple partitions, they get drawn separately anyhow
+    # (unless they're Placeholders). Don't be tempted to use
+    # combined_array_to_id everywhere.
+
+    # {{{ draw overall outputs
+
+    combined_array_to_id: Dict[ArrayOrNames, str] = {}
+    for part_id in partition.parts.keys():
+        combined_array_to_id.update(part_id_to_array_to_id[part_id])
+
+    _emit_name_cluster(
+            emitter, (),
+            {name: partition.name_to_output[name]
+                for name in partition.overall_output_names},
+            combined_array_to_id, id_gen, "Overall outputs")
+
+    # }}}
+
+    return emitter.generate()
 
 
 def show_dot_graph(result: Union[str, Array, DictOfNamedArrays,
diff --git a/test/test_codegen.py b/test/test_codegen.py
index e02fbc85c17ea1ce086d690526747b872350dba7..b6d0bba8b4eccde3f9cee9a81c98ab9eaf14d8a5 100755
--- a/test/test_codegen.py
+++ b/test/test_codegen.py
@@ -1854,15 +1854,16 @@ def test_pad(ctx_factory):
     np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
 
 
-def test_function_call(ctx_factory):
+def test_function_call(ctx_factory, visualize=False):
+    from functools import partial
     cl_ctx = ctx_factory()
     cq = cl.CommandQueue(cl_ctx)
 
     def f(x):
         return 2*x
 
-    def g(x):
-        return 2*x, 3*x
+    def g(tracer, x):
+        return tracer(f, x), 3*x
 
     def h(x, y):
         return {"twice": 2*x+y, "thrice": 3*x+y}
@@ -1870,7 +1871,7 @@ def test_function_call(ctx_factory):
     def build_expression(tracer):
         x = pt.arange(500, dtype=np.float32)
         twice_x = tracer(f, x)
-        twice_x_2, thrice_x_2 = tracer(g, x)
+        twice_x_2, thrice_x_2 = tracer(partial(g, tracer), x)
 
         result = tracer(h, x, 2*x)
         twice_x_3 = result["twice"]
@@ -1881,13 +1882,19 @@ def test_function_call(ctx_factory):
                 "baz": 65 * twice_x,
                 "quux": 7 * twice_x_2}
 
-    result1 = pt.tag_all_calls_to_be_inlined(
+    result_with_functions = pt.tag_all_calls_to_be_inlined(
         pt.make_dict_of_named_arrays(build_expression(pt.trace_call)))
-    result2 = pt.make_dict_of_named_arrays(
+    result_without_functions = pt.make_dict_of_named_arrays(
         build_expression(lambda fn, *args: fn(*args)))
 
-    _, outputs = pt.generate_loopy(result1)(cq, out_host=True)
-    _, expected = pt.generate_loopy(result2)(cq, out_host=True)
+    # test that visualizing graphs with functions works
+    dot = pt.get_dot_graph(result_with_functions)
+
+    if visualize:
+        pt.show_dot_graph(dot)
+
+    _, outputs = pt.generate_loopy(result_with_functions)(cq, out_host=True)
+    _, expected = pt.generate_loopy(result_without_functions)(cq, out_host=True)
 
     assert len(outputs) == len(expected)
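Usage sketch (not part of the patch; the function f and the names below are made up for illustration). It exercises the code path the updated test_function_call covers: pt.get_dot_graph on a graph that still contains a traced (non-inlined) function call, which the reworked emitter should render as its own cluster_* subgraph alongside the "Arguments"/"Returns" clusters.

    import numpy as np
    import pytato as pt

    def f(x):
        return 2*x

    x = pt.arange(500, dtype=np.float32)
    # trace_call keeps f as a Call node in the graph rather than inlining it
    twice_x = pt.trace_call(f, x)
    result = pt.make_dict_of_named_arrays({"out": twice_x})

    dot = pt.get_dot_graph(result)   # DOT source as a string
    # pt.show_dot_graph(dot)         # optionally render it (requires graphviz)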