diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py
index b836e38147362ee4e03ed08dc0b64208cb7f46a3..0008cea20859c0c2accdd260e022863e2055eb59 100644
--- a/pytato/visualization/dot.py
+++ b/pytato/visualization/dot.py
@@ -26,18 +26,19 @@ THE SOFTWARE.
 """
 
 
-import contextlib
+from functools import partial
 import dataclasses
 import html
 
-from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List,
-        Mapping, Hashable, Any, FrozenSet)
+from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List,
+        Mapping, Any, FrozenSet, Set, Optional)
 
 from pytools import UniqueNameGenerator
 from pytools.tag import Tag
-from pytools.codegen import CodeGenerator as CodeGeneratorBase
 from pytato.loopy import LoopyCall
-from pytato.function import Call, NamedCallResult
+from pytato.function import Call, FunctionDefinition, NamedCallResult
+from pytato.tags import FunctionIdentifier
+from pytools.codegen import remove_common_indentation
 
 from pytato.array import (
         Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase,
@@ -45,7 +46,7 @@ from pytato.array import (
         IndexBase)
 
 from pytato.codegen import normalize_outputs
-from pytato.transform import CachedMapper, ArrayOrNames
+from pytato.transform import CachedMapper, ArrayOrNames, InputGatherer
 
 from pytato.distributed.partition import (
         DistributedGraphPartition, DistributedGraphPart, PartId)
@@ -63,13 +64,88 @@ __doc__ = """
 """
 
 
+# {{{ DotEmitter
+
+@dataclasses.dataclass
+class _SubgraphTree:
+    contents: Optional[List[str]]
+    subgraphs: Dict[str, "_SubgraphTree"]
+
+
+class DotEmitter:
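+    # Collects dot source lines keyed by a "subgraph path", i.e. a tuple of
+    # nested subgraph names; generate() later assembles these into nested
+    # "subgraph" blocks of the emitted dot source.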
+    def __init__(self) -> None:
+        self.subgraph_to_lines: Dict[Tuple[str, ...], List[str]] = {}
+
+    def __call__(self, subgraph_path: Tuple[str, ...], s: str) -> None:
+        line_list = self.subgraph_to_lines.setdefault(subgraph_path, [])
+
+        if not s.strip():
+            line_list.append("")
+        else:
+            if "\n" in s:
+                s = remove_common_indentation(s)
+
+            for line in s.split("\n"):
+                line_list.append(line)
+
+    def _get_subgraph_tree(self) -> _SubgraphTree:
+        subgraph_tree = _SubgraphTree(contents=None, subgraphs={})
+
+        def insert_into_subgraph_tree(
+                root: _SubgraphTree, path: Tuple[str, ...], contents: List[str]
+                ) -> None:
+            if not path:
+                assert root.contents is None
+                root.contents = contents
+
+            else:
+                subgraph = root.subgraphs.setdefault(
+                        path[0],
+                        _SubgraphTree(contents=None, subgraphs={}))
+
+                insert_into_subgraph_tree(subgraph, path[1:], contents)
+
+        for sgp, lines in self.subgraph_to_lines.items():
+            insert_into_subgraph_tree(subgraph_tree, sgp, lines)
+
+        return subgraph_tree
+
+    def generate(self) -> str:
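+        # Assemble the collected lines into nested "subgraph" blocks,
+        # indenting by four spaces per nesting level.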
+        result = ["digraph computation {"]
+
+        indent_level = 1
+
+        def emit_subgraph(sg: _SubgraphTree) -> None:
+            nonlocal indent_level
+
+            indent = (indent_level*4)*" "
+            if sg.contents:
+                for ln in sg.contents:
+                    result.append(indent + ln)
+
+            indent_level += 1
+            for sg_name, sub_sg in sg.subgraphs.items():
+                result.append(f"{indent}subgraph {sg_name} {{")
+                emit_subgraph(sub_sg)
+                result.append(f"{indent}" "}")
+            indent_level -= 1
+
+        emit_subgraph(self._get_subgraph_tree())
+
+        result.append("}")
+
+        return "\n".join(result)
+
+# }}}
+
+
 # {{{ array -> dot node converter
 
 @dataclasses.dataclass
-class DotNodeInfo:
+class _DotNodeInfo:
     title: str
     fields: Dict[str, str]
-    edges: Dict[str, ArrayOrNames]
+    edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]]
 
 
 def stringify_tags(tags: FrozenSet[Tag]) -> str:
@@ -89,17 +165,18 @@ def stringify_shape(shape: ShapeType) -> str:
 class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
     def __init__(self) -> None:
         super().__init__()
-        self.nodes: Dict[ArrayOrNames, DotNodeInfo] = {}
+        self.node_to_dot: Dict[ArrayOrNames, _DotNodeInfo] = {}
+        self.functions: Set[FunctionDefinition] = set()
 
-    def get_common_dot_info(self, expr: Array) -> DotNodeInfo:
+    def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
         title = type(expr).__name__
         fields = {"addr": hex(id(expr)),
                 "shape": stringify_shape(expr.shape),
                 "dtype": str(expr.dtype),
                 "tags": stringify_tags(expr.tags)}
 
-        edges: Dict[str, ArrayOrNames] = {}
-        return DotNodeInfo(title, fields, edges)
+        edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
+        return _DotNodeInfo(title, fields, edges)
 
     # type-ignore-reason: incompatible with supertype
     def handle_unsupported_array(self,  # type: ignore[override]
@@ -126,7 +203,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             else:
                 info.fields[field] = str(attr)
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_data_wrapper(self, expr: DataWrapper) -> None:
         info = self.get_common_dot_info(expr)
@@ -138,7 +215,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
         with np.printoptions(threshold=4, precision=2):
             info.fields["data"] = str(expr.data)
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_index_lambda(self, expr: IndexLambda) -> None:
         info = self.get_common_dot_info(expr)
@@ -148,7 +225,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             self.rec(val)
             info.edges[name] = val
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_stack(self, expr: Stack) -> None:
         info = self.get_common_dot_info(expr)
@@ -158,7 +235,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             self.rec(array)
             info.edges[str(i)] = array
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     map_concatenate = map_stack
 
@@ -189,7 +266,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
         self.rec(expr.array)
         info.edges["array"] = expr.array
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     map_contiguous_advanced_index = map_basic_index
     map_non_contiguous_advanced_index = map_basic_index
@@ -202,27 +279,27 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
             self.rec(val)
             info.edges[f"{iarg}: {access_descr}"] = val
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
-        edges: Dict[str, ArrayOrNames] = {}
+        edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
         for name, val in expr._data.items():
             edges[name] = val
             self.rec(val)
 
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=type(expr).__name__,
                 fields={},
                 edges=edges)
 
     def map_loopy_call(self, expr: LoopyCall) -> None:
-        edges: Dict[str, ArrayOrNames] = {}
+        edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
         for name, arg in expr.bindings.items():
             if isinstance(arg, Array):
                 edges[name] = arg
                 self.rec(arg)
 
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=type(expr).__name__,
                 fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint},
                 edges=edges)
@@ -242,15 +319,19 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
 
         info.fields["comm_tag"] = str(expr.send.comm_tag)
 
-        self.nodes[expr] = info
+        self.node_to_dot[expr] = info
 
     def map_call(self, expr: Call) -> None:
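+        # Function bodies are not expanded inline here; record the definition
+        # so that it can be drawn as a separate cluster, and point an edge
+        # (with an empty label) from the call node at it.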
+        self.functions.add(expr.function)
+
         for bnd in expr.bindings.values():
             self.rec(bnd)
 
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
             title=expr.__class__.__name__,
-            edges=dict(expr.bindings),
+            edges={
+                "": expr.function,
+                **expr.bindings},
             fields={
                 "addr": hex(id(expr)),
                 "tags": stringify_tags(expr.tags),
@@ -259,30 +340,24 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
 
     def map_named_call_result(self, expr: NamedCallResult) -> None:
         self.rec(expr._container)
-        self.nodes[expr] = DotNodeInfo(
+        self.node_to_dot[expr] = _DotNodeInfo(
                 title=expr.__class__.__name__,
                 edges={"": expr._container},
                 fields={"addr": hex(id(expr)),
                         "name": expr.name},
         )
 
+# }}}
+
 
 def dot_escape(s: str) -> str:
     # "\" and HTML are significant in graphviz.
-    return html.escape(s.replace("\\", "\\\\"))
-
+    return html.escape(s.replace("\\", "\\\\").replace(" ", "_"))
 
-class DotEmitter(CodeGeneratorBase):
-    @contextlib.contextmanager
-    def block(self, name: str) -> Iterator[None]:
-        self(name + " {")
-        self.indent()
-        yield
-        self.dedent()
-        self("}")
 
+# {{{ emit helpers
 
-def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
+def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str],
         dot_node_id: str, color: str = "white") -> None:
     td_attrib = 'border="0"'
     table_attrib = 'border="0" cellborder="1" cellspacing="0"'
@@ -300,84 +375,189 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
     emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color))
 
 
-def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames],
+def _emit_name_cluster(
+        emit: DotEmitter, subgraph_path: Tuple[str, ...],
+        names: Mapping[str, ArrayOrNames],
         array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str],
         label: str) -> None:
     edges = []
 
-    with emit.block("subgraph cluster_%s" % label):
-        emit("node [shape=ellipse]")
-        emit('label="%s"' % label)
+    cluster_subgraph_path = subgraph_path + (f"cluster_{dot_escape(label)}",)
+    emit_cluster = partial(emit, cluster_subgraph_path)
+    emit_cluster("node [shape=ellipse]")
+    emit_cluster(f'label="{label}"')
 
-        for name, array in names.items():
-            name_id = id_gen(label)
-            emit('%s [label="%s"]' % (name_id, dot_escape(name)))
-            array_id = array_to_id[array]
-            # Edges must be outside the cluster.
-            edges.append((name_id, array_id))
+    for name, array in names.items():
+        name_id = id_gen(dot_escape(name))
+        emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name)))
+        array_id = array_to_id[array]
+        # Edges must be outside the cluster.
+        edges.append((name_id, array_id))
 
     for name_id, array_id in edges:
-        emit("%s -> %s" % (name_id, array_id))
-
-# }}}
+        emit(subgraph_path, "%s -> %s" % (array_id, name_id))
 
 
-def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
-    r"""Return a string in the `dot <https://graphviz.org>`_ language depicting the
-    graph of the computation of *result*.
-
-    :arg result: Outputs of the computation (cf.
-        :func:`pytato.generate_loopy`).
-    """
-    outputs: DictOfNamedArrays = normalize_outputs(result)
-    del result
-
-    mapper = ArrayToDotNodeInfoMapper()
-    for elem in outputs._data.values():
-        mapper(elem)
-
-    nodes = mapper.nodes
-
+def _emit_function(
+        emitter: DotEmitter, subgraph_path: Tuple[str, ...],
+        id_gen: UniqueNameGenerator,
+        node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo],
+        func_to_id: Mapping[FunctionDefinition, str],
+        outputs: Mapping[str, Array]) -> None:
     input_arrays: List[Array] = []
     internal_arrays: List[ArrayOrNames] = []
     array_to_id: Dict[ArrayOrNames, str] = {}
 
-    id_gen = UniqueNameGenerator()
-    for array in nodes:
+    emit = partial(emitter, subgraph_path)
+    for array in node_to_dot:
         array_to_id[array] = id_gen("array")
         if isinstance(array, InputArgumentBase):
             input_arrays.append(array)
         else:
             internal_arrays.append(array)
 
-    emit = DotEmitter()
+    # Emit inputs.
+    input_subgraph_path = subgraph_path + ("cluster_inputs",)
+    emit_input = partial(emitter, input_subgraph_path)
+    emit_input('label="Arguments"')
+
+    for array in input_arrays:
+        _emit_array(
+                emit_input,
+                node_to_dot[array].title,
+                node_to_dot[array].fields,
+                array_to_id[array])
+
+    # Emit non-inputs.
+    for array in internal_arrays:
+        _emit_array(emit,
+                    node_to_dot[array].title,
+                    node_to_dot[array].fields,
+                    array_to_id[array])
+
+    # Emit edges.
+    for array, node in node_to_dot.items():
+        for label, tail_item in node.edges.items():
+            head = array_to_id[array]
+            if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
+                tail = array_to_id[tail_item]
+            elif isinstance(tail_item, FunctionDefinition):
+                tail = func_to_id[tail_item]
+            else:
+                raise ValueError(
+                        f"unexpected type of tail on edge: {type(tail_item)}")
+
+            emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label)))
 
-    with emit.block("digraph computation"):
-        emit("node [shape=rectangle]")
+    # Emit output/namespace name mappings.
+    _emit_name_cluster(
+            emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns")
 
-        # Emit inputs.
-        with emit.block("subgraph cluster_Inputs"):
-            emit('label="Inputs"')
-            for array in input_arrays:
-                _emit_array(emit,
-                        nodes[array].title, nodes[array].fields, array_to_id[array])
+# }}}
 
-        # Emit non-inputs.
-        for array in internal_arrays:
-            _emit_array(emit,
-                    nodes[array].title, nodes[array].fields, array_to_id[array])
 
-        # Emit edges.
-        for array, node in nodes.items():
-            for label, tail_array in node.edges.items():
-                tail = array_to_id[tail_array]
-                head = array_to_id[array]
-                emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label)))
+# {{{ information gathering
 
-        # Emit output/namespace name mappings.
-        _emit_name_cluster(emit, outputs._data, array_to_id, id_gen, label="Outputs")
+def _get_function_name(f: FunctionDefinition) -> Optional[str]:
+    func_id_tags = f.tags_of_type(FunctionIdentifier)
+    if func_id_tags:
+        func_id_tag, = func_id_tags
+        return str(func_id_tag.identifier)
+    else:
+        return None
+
+
+def _gather_partition_node_information(
+        id_gen: UniqueNameGenerator,
+        partition: DistributedGraphPartition
+        ) -> Tuple[
+                Mapping[PartId, Mapping[FunctionDefinition, str]],
+                Mapping[Tuple[PartId, Optional[FunctionDefinition]],
+                        Mapping[ArrayOrNames, _DotNodeInfo]]
+                ]:
+    part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {}
+    part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]],
+                     Dict[ArrayOrNames, _DotNodeInfo]] = {}
+
+    for part in partition.parts.values():
+        mapper = ArrayToDotNodeInfoMapper()
+        for out_name in part.output_names:
+            mapper(partition.name_to_output[out_name])
+
+        part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot
+        part_id_to_func_to_id[part.pid] = {}
 
-    return emit.get()
+        # It is important that seen functions are emitted callee-first.
+        # (Otherwise function 'entry' nodes will get declared in the wrong
+        # cluster.) So use a data type that preserves order.
+        seen_functions: List[FunctionDefinition] = []
+
+        def gather_function_info(f: FunctionDefinition) -> None:
+            key = (part.pid, f)  # noqa: B023
+            if key in part_id_func_to_node_info:
+                return
+
+            mapper = ArrayToDotNodeInfoMapper()
+            for elem in f.returns.values():
+                mapper(elem)
+
+            part_id_func_to_node_info[key] = mapper.node_to_dot
+
+            for subfunc in mapper.functions:
+                gather_function_info(subfunc)
+
+            if f not in seen_functions:  # noqa: B023
+                seen_functions.append(f)  # noqa: B023
+
+        for f in mapper.functions:
+            gather_function_info(f)
+
+        # Again, it is important to preserve function order; the per-part
+        # dict built here relies on dicts preserving insertion order.
+        for f in seen_functions:
+            func_name = _get_function_name(f)
+            if func_name is not None:
+                fid = id_gen(dot_escape(func_name))
+            else:
+                fid = id_gen("func")
+
+            part_id_to_func_to_id.setdefault(part.pid, {})[f] = fid
+
+    return part_id_to_func_to_id, part_id_func_to_node_info
+
+# }}}
+
+
+def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
+    r"""Return a string in the `dot <https://graphviz.org>`_ language depicting the
+    graph of the computation of *result*.
+
+    :arg result: Outputs of the computation (cf.
+        :func:`pytato.generate_loopy`).
+    """
+
+    outputs: DictOfNamedArrays = normalize_outputs(result)
+
+    return get_dot_graph_from_partition(
+            DistributedGraphPartition(
+                parts={
+                    None: DistributedGraphPart(
+                        pid=None,
+                        needed_pids=frozenset(),
+                        user_input_names=frozenset(
+                            expr.name
+                            for expr in InputGatherer()(outputs)
+                            if isinstance(expr, Placeholder)
+                            ),
+                        partition_input_names=frozenset(),
+                        output_names=frozenset(outputs.keys()),
+                        name_to_recv_node={},
+                        name_to_send_nodes={},
+                        )
+                    },
+                name_to_output=outputs._data,
+                overall_output_names=tuple(outputs),
+                ))
 
 
 def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
@@ -386,170 +566,226 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
 
     :arg partition: Outputs of :func:`~pytato.find_distributed_partition`.
     """
-    # Maps each partition to a dict of its arrays with the node info
-    part_id_to_node_info: Dict[Hashable, Dict[ArrayOrNames, DotNodeInfo]] = {}
+    id_gen = UniqueNameGenerator()
 
-    for part in partition.parts.values():
-        mapper = ArrayToDotNodeInfoMapper()
-        for out_name in part.output_names:
-            mapper(partition.name_to_output[out_name])
+    # {{{ gather up node info, per partition and per function
 
-        part_id_to_node_info[part.pid] = mapper.nodes
+    # The "None" function is the body of the partition.
 
-    id_gen = UniqueNameGenerator()
+    part_id_to_func_to_id, part_id_func_to_node_info = \
+            _gather_partition_node_information(id_gen, partition)
 
-    emit = DotEmitter()
+    # }}}
+
+    emitter = DotEmitter()
+    emit_root = partial(emitter, ())
 
     emitted_placeholders = set()
 
-    with emit.block("digraph computation"):
-        emit("node [shape=rectangle]")
-        placeholder_to_id: Dict[ArrayOrNames, str] = {}
-        part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {}
-
-        # First pass: generate names for all nodes
-        for part in partition.parts.values():
-            array_to_id = {}
-            for array, _ in part_id_to_node_info[part.pid].items():
-                if isinstance(array, Placeholder):
-                    # Placeholders are only emitted once
-                    if array in placeholder_to_id:
-                        node_id = placeholder_to_id[array]
-                    else:
-                        node_id = id_gen("array")
-                        placeholder_to_id[array] = node_id
+    emit_root("node [shape=rectangle]")
+
+    placeholder_to_id: Dict[ArrayOrNames, str] = {}
+    part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {}
+
+    part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts}
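+    # dot_escape is not injective; make sure distinct part IDs did not
+    # collapse onto the same escaped name.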
+    assert len(set(part_id_to_id.values())) == len(partition.parts)
+
+    # {{{ generate names for all nodes in the root/None function
+
+    for part in partition.parts.values():
+        array_to_id = {}
+        for array in part_id_func_to_node_info[part.pid, None].keys():
+            if isinstance(array, Placeholder):
+                # Placeholders are only emitted once
+                if array in placeholder_to_id:
+                    node_id = placeholder_to_id[array]
                 else:
                     node_id = id_gen("array")
-                array_to_id[array] = node_id
-            part_id_to_array_to_id[part.pid] = array_to_id
-
-        # Second pass: emit the graph.
-        for part in partition.parts.values():
-            array_to_id = part_id_to_array_to_id[part.pid]
-
-            # {{{ emit receives nodes if distributed
-
-            if isinstance(part, DistributedGraphPart):
-                part_dist_recv_var_name_to_node_id = {}
-                for name, recv in (
-                        part.name_to_recv_node.items()):
-                    node_id = id_gen("recv")
-                    _emit_array(emit, "DistributedRecv", {
-                        "shape": stringify_shape(recv.shape),
-                        "dtype": str(recv.dtype),
-                        "src_rank": str(recv.src_rank),
-                        "comm_tag": str(recv.comm_tag),
-                        }, node_id)
+                    placeholder_to_id[array] = node_id
+            else:
+                node_id = id_gen("array")
+            array_to_id[array] = node_id
+        part_id_to_array_to_id[part.pid] = array_to_id
+
+    # }}}
+
+    # {{{ emit the graph
+
+    for part in partition.parts.values():
+        array_to_id = part_id_to_array_to_id[part.pid]
 
-                    part_dist_recv_var_name_to_node_id[name] = node_id
+        is_trivial_partition = part.pid is None and len(partition.parts) == 1
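+        # A single part with pid None (as built by get_dot_graph) means the
+        # graph was never actually partitioned; emit its nodes at the root
+        # instead of wrapping them in a part cluster.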
+        if is_trivial_partition:
+            part_subgraph_path: Tuple[str, ...] = ()
+        else:
+            part_subgraph_path = (f"cluster_{part_id_to_id[part.pid]}",)
+
+        emit_part = partial(emitter, part_subgraph_path)
+
+        if not is_trivial_partition:
+            emit_part("style=dashed")
+            emit_part(f'label="{part.pid}"')
+
+        # {{{ emit functions
+
+        # It is important that seen functions are emitted callee-first.
+        # Here we're relying on the part_id_to_func_to_id dict to preserve order.
+
+        for func, fid in part_id_to_func_to_id[part.pid].items():
+            func_subgraph_path = part_subgraph_path + (f"cluster_{fid}",)
+            label = _get_function_name(func) or fid
+
+            emitter(func_subgraph_path, f'label="{label}"')
+            emitter(func_subgraph_path, f'{fid} [label="{label}",shape="ellipse"]')
+
+            _emit_function(emitter, func_subgraph_path,
+                           id_gen, part_id_func_to_node_info[part.pid, func],
+                           part_id_to_func_to_id[part.pid],
+                           func.returns)
+
+        # }}}
+
+        # {{{ emit receives nodes
+
+        part_dist_recv_var_name_to_node_id = {}
+        for name, recv in (
+                part.name_to_recv_node.items()):
+            node_id = id_gen("recv")
+            _emit_array(emit_part, "DistributedRecv", {
+                "shape": stringify_shape(recv.shape),
+                "dtype": str(recv.dtype),
+                "src_rank": str(recv.src_rank),
+                "comm_tag": str(recv.comm_tag),
+                }, node_id)
+
+            part_dist_recv_var_name_to_node_id[name] = node_id
+
+        # }}}
+
+        part_node_to_info = part_id_func_to_node_info[part.pid, None]
+        input_arrays: List[Array] = []
+        internal_arrays: List[ArrayOrNames] = []
+
+        for array in part_node_to_info.keys():
+            if isinstance(array, InputArgumentBase):
+                input_arrays.append(array)
             else:
-                part_dist_recv_var_name_to_node_id = {}
+                internal_arrays.append(array)
 
-            # }}}
+        # {{{ emit inputs
 
-            part_node_to_info = part_id_to_node_info[part.pid]
-            input_arrays: List[Array] = []
-            internal_arrays: List[ArrayOrNames] = []
+        # Placeholders are de-duplicated, i.e. the same Placeholder object may
+        # be shared among partitions. Therefore, they should not live inside
+        # a (dot) subgraph; otherwise they would be forced into multiple
+        # subgraphs.
 
-            for array in part_node_to_info.keys():
-                if isinstance(array, InputArgumentBase):
-                    input_arrays.append(array)
+        for array in input_arrays:
+            if not isinstance(array, Placeholder):
+                _emit_array(emit_part,
+                            part_node_to_info[array].title,
+                            part_node_to_info[array].fields,
+                            array_to_id[array], "deepskyblue")
+            else:
+                # Is a Placeholder
+                if array in emitted_placeholders:
+                    continue
+
+                _emit_array(emit_root,
+                            part_node_to_info[array].title,
+                            part_node_to_info[array].fields,
+                            array_to_id[array], "deepskyblue")
+
+                # Emit cross-partition edges
+                if array.name in part_dist_recv_var_name_to_node_id:
+                    tgt = part_dist_recv_var_name_to_node_id[array.name]
+                    emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]")
+                    emitted_placeholders.add(array)
+                elif array.name in part.user_input_names:
+                    # User inputs: no incoming edges to draw.
+                    pass
                 else:
-                    internal_arrays.append(array)
-
-            # {{{ emit inputs
-
-            # Placeholders are unique, i.e. the same Placeholder object may be
-            # shared among partitions. Therefore, they should not live inside
-            # the (dot) subgraph, otherwise they would be forced into multiple
-            # subgraphs.
-
-            for array in input_arrays:
-                # Non-Placeholders are emitted *inside* their subgraphs below.
-                if isinstance(array, Placeholder):
-                    if array not in emitted_placeholders:
-                        _emit_array(emit,
-                                    part_node_to_info[array].title,
-                                    part_node_to_info[array].fields,
-                                    array_to_id[array], "deepskyblue")
-
-                        # Emit cross-partition edges
-                        if array.name in part_dist_recv_var_name_to_node_id:
-                            tgt = part_dist_recv_var_name_to_node_id[array.name]
-                            emit(f"{tgt} -> {array_to_id[array]} [style=dotted]")
-                            emitted_placeholders.add(array)
-                        elif array.name in part.user_input_names:
-                            # These are placeholders for external input. They
-                            # are cleanly associated with a single partition
-                            # and thus emitted below.
-                            pass
-                        else:
-                            # placeholder for a value from a different partition
-                            computing_pid = None
-                            for other_part in partition.parts.values():
-                                if array.name in other_part.output_names:
-                                    computing_pid = other_part.pid
-                                    break
-                            assert computing_pid is not None
-                            tgt = part_id_to_array_to_id[computing_pid][
-                                    partition.name_to_output[array.name]]
-                            emit(f"{tgt} -> {array_to_id[array]} [style=dashed]")
-                            emitted_placeholders.add(array)
-
-            # }}}
-
-            with emit.block(f'subgraph "cluster_part_{part.pid}"'):
-                emit("style=dashed")
-                emit(f'label="{part.pid}"')
-
-                for array in input_arrays:
-                    if (not isinstance(array, Placeholder)
-                            or array.name in part.user_input_names):
-                        _emit_array(emit,
-                                    part_node_to_info[array].title,
-                                    part_node_to_info[array].fields,
-                                    array_to_id[array], "deepskyblue")
-
-                # Emit internal nodes
-                for array in internal_arrays:
-                    _emit_array(emit,
-                                part_node_to_info[array].title,
-                                part_node_to_info[array].fields,
-                                array_to_id[array])
-
-                # {{{ emit send nodes if distributed
-
-                deferred_send_edges = []
-                if isinstance(part, DistributedGraphPart):
-                    for name, sends in (
-                            part.name_to_send_nodes.items()):
-                        for send in sends:
-                            node_id = id_gen("send")
-                            _emit_array(emit, "DistributedSend", {
-                                "dest_rank": str(send.dest_rank),
-                                "comm_tag": str(send.comm_tag),
-                                }, node_id)
-
-                            deferred_send_edges.append(
-                                    f"{array_to_id[send.data]} -> {node_id}"
-                                    f'[style=dotted, label="{dot_escape(name)}"]')
-
-                # }}}
-
-            # If an edge is emitted in a subgraph, it drags its nodes into the
-            # subgraph, too. Not what we want.
-            for edge in deferred_send_edges:
-                emit(edge)
-
-            # Emit intra-partition edges
-            for array, node in part_node_to_info.items():
-                for label, tail_array in node.edges.items():
-                    tail = array_to_id[tail_array]
-                    head = array_to_id[array]
-                    emit('%s -> %s [label="%s"]' %
-                        (tail, head, dot_escape(label)))
-
-    return emit.get()
+                    # placeholder for a value from a different partition
+                    computing_pid = None
+                    for other_part in partition.parts.values():
+                        if array.name in other_part.output_names:
+                            computing_pid = other_part.pid
+                            break
+                    assert computing_pid is not None
+                    tgt = part_id_to_array_to_id[computing_pid][
+                            partition.name_to_output[array.name]]
+                    emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]")
+                    emitted_placeholders.add(array)
+
+        # }}}
+
+        # Emit internal nodes
+        for array in internal_arrays:
+            _emit_array(emit_part,
+                        part_node_to_info[array].title,
+                        part_node_to_info[array].fields,
+                        array_to_id[array])
+
+        # {{{ emit send nodes if distributed
+
+        if isinstance(part, DistributedGraphPart):
+            for name, sends in part.name_to_send_nodes.items():
+                for send in sends:
+                    node_id = id_gen("send")
+                    _emit_array(emit_part, "DistributedSend", {
+                        "dest_rank": str(send.dest_rank),
+                        "comm_tag": str(send.comm_tag),
+                        }, node_id)
+
+                    # If an edge is emitted in a subgraph, it drags its
+                    # nodes into the subgraph, too. Not what we want.
+                    emit_root(
+                            f"{array_to_id[send.data]} -> {node_id}"
+                            f'[style=dotted, label="{dot_escape(name)}"]')
+
+        # }}}
+
+        # Emit intra-partition edges
+        for array, node in part_node_to_info.items():
+            for label, tail_item in node.edges.items():
+                head = array_to_id[array]
+
+                if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
+                    tail = array_to_id[tail_item]
+                elif isinstance(tail_item, FunctionDefinition):
+                    tail = part_id_to_func_to_id[part.pid][tail_item]
+                else:
+                    raise ValueError(
+                            f"unexpected type of tail on edge: {type(tail_item)}")
+
+                emit_root('%s -> %s [label="%s"]' %
+                    (tail, head, dot_escape(label)))
+
+        _emit_name_cluster(
+                emitter, part_subgraph_path,
+                {name: partition.name_to_output[name] for name in part.output_names},
+                array_to_id, id_gen, "Part outputs")
+
+    # }}}
+
+    # Arrays may occur in multiple partitions; they get drawn separately anyway
+    # (unless they're Placeholders). Don't be tempted to use
+    # combined_array_to_id everywhere.
+
+    # {{{ draw overall outputs
+
+    combined_array_to_id: Dict[ArrayOrNames, str] = {}
+    for part_id in partition.parts.keys():
+        combined_array_to_id.update(part_id_to_array_to_id[part_id])
+
+    _emit_name_cluster(
+            emitter, (),
+            {name: partition.name_to_output[name]
+             for name in partition.overall_output_names},
+            combined_array_to_id, id_gen, "Overall outputs")
+
+    # }}}
+
+    return emitter.generate()
 
 
 def show_dot_graph(result: Union[str, Array, DictOfNamedArrays,
diff --git a/test/test_codegen.py b/test/test_codegen.py
index e02fbc85c17ea1ce086d690526747b872350dba7..b6d0bba8b4eccde3f9cee9a81c98ab9eaf14d8a5 100755
--- a/test/test_codegen.py
+++ b/test/test_codegen.py
@@ -1854,15 +1854,16 @@ def test_pad(ctx_factory):
         np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
 
 
-def test_function_call(ctx_factory):
+def test_function_call(ctx_factory, visualize=False):
+    from functools import partial
     cl_ctx = ctx_factory()
     cq = cl.CommandQueue(cl_ctx)
 
     def f(x):
         return 2*x
 
-    def g(x):
-        return 2*x, 3*x
+    def g(tracer, x):
+        return tracer(f, x), 3*x
 
     def h(x, y):
         return {"twice": 2*x+y, "thrice": 3*x+y}
@@ -1870,7 +1871,7 @@ def test_function_call(ctx_factory):
     def build_expression(tracer):
         x = pt.arange(500, dtype=np.float32)
         twice_x = tracer(f, x)
-        twice_x_2, thrice_x_2 = tracer(g, x)
+        twice_x_2, thrice_x_2 = tracer(partial(g, tracer), x)
 
         result = tracer(h, x, 2*x)
         twice_x_3 = result["twice"]
@@ -1881,13 +1882,19 @@ def test_function_call(ctx_factory):
                 "baz": 65 * twice_x,
                 "quux": 7 * twice_x_2}
 
-    result1 = pt.tag_all_calls_to_be_inlined(
+    result_with_functions = pt.tag_all_calls_to_be_inlined(
         pt.make_dict_of_named_arrays(build_expression(pt.trace_call)))
-    result2 = pt.make_dict_of_named_arrays(
+    result_without_functions = pt.make_dict_of_named_arrays(
         build_expression(lambda fn, *args: fn(*args)))
 
-    _, outputs = pt.generate_loopy(result1)(cq, out_host=True)
-    _, expected = pt.generate_loopy(result2)(cq, out_host=True)
+    # test that visualizing graphs with functions works
+    dot = pt.get_dot_graph(result_with_functions)
+
+    if visualize:
+        pt.show_dot_graph(dot)
+
+    _, outputs = pt.generate_loopy(result_with_functions)(cq, out_host=True)
+    _, expected = pt.generate_loopy(result_without_functions)(cq, out_host=True)
 
     assert len(outputs) == len(expected)