From 6636a0ff6d720e2e7025e108b484b16910edf9cc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 3 May 2023 17:50:53 -0500 Subject: [PATCH] New distributed-memory DAG partitioner Co-authored-by: Matt Smith --- doc/dag.rst | 10 - doc/design.rst | 5 - doc/distributed.rst | 5 + doc/index.rst | 1 + examples/distributed.py | 38 +- examples/partition.py | 65 -- pytato/__init__.py | 8 +- pytato/distributed/__init__.py | 4 - pytato/distributed/execute.py | 55 +- pytato/distributed/nodes.py | 14 +- pytato/distributed/partition.py | 1335 ++++++++++++++++++------------- pytato/distributed/tags.py | 21 +- pytato/distributed/verify.py | 199 +++-- pytato/partition.py | 474 ----------- pytato/transform/__init__.py | 175 ---- pytato/visualization.py | 44 +- test/test_codegen.py | 67 -- test/test_distributed.py | 141 ++-- 18 files changed, 1098 insertions(+), 1563 deletions(-) create mode 100644 doc/distributed.rst delete mode 100644 examples/partition.py delete mode 100644 pytato/partition.py diff --git a/doc/dag.rst b/doc/dag.rst index 80a7c33..68207f1 100644 --- a/doc/dag.rst +++ b/doc/dag.rst @@ -25,18 +25,8 @@ Stringifying Expression Graphs .. _partitioning: -Partitioning Array Expression Graphs -==================================== - -.. automodule:: pytato.partition - .. _distributed: -Support for Distributed-Memory/Message Passing -============================================== - -.. automodule:: pytato.distributed - Utilities and Diagnostics ========================= diff --git a/doc/design.rst b/doc/design.rst index 5eaee80..93a4ebd 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -154,11 +154,6 @@ Reserved Identifiers names of :class:`~pytato.array.DataWrapper` arguments that are not supplied by the user. - - ``_pt_part_ph``: Used to automatically generate identifiers for - names of :class:`~pytato.array.Placeholder` that represent data - transport across parts of a partitioned DAG - (cf. :func:`~pytato.partition.find_partition`). - - ``_pt_dist``: Used to automatically generate identifiers for placeholders at distributed communication boundaries. diff --git a/doc/distributed.rst b/doc/distributed.rst new file mode 100644 index 0000000..69d0728 --- /dev/null +++ b/doc/distributed.rst @@ -0,0 +1,5 @@ +Support for Distributed-Memory/Message Passing +============================================== + +.. automodule:: pytato.distributed + diff --git a/doc/index.rst b/doc/index.rst index 81f1c43..5a08477 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -19,6 +19,7 @@ Here's an example usage: array dag + distributed codegen internal design diff --git a/examples/distributed.py b/examples/distributed.py index b763b71..3301f96 100644 --- a/examples/distributed.py +++ b/examples/distributed.py @@ -21,16 +21,24 @@ def main(): x_in = rng.integers(100, size=(4, 4)) x = pt.make_data_wrapper(x_in) - mytag = (main, "x") - halo = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag, + mytag_x = (main, "x") + x_plus = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag_x, stapled_to=make_distributed_recv( - src_rank=(rank+1) % size, comm_tag=mytag, shape=(4, 4), dtype=int)) + src_rank=(rank+1) % size, comm_tag=mytag_x, shape=(4, 4), dtype=int)) - y = x+halo + y = x+x_plus + + mytag_y = (main, "y") + y_plus = staple_distributed_send(y, dest_rank=(rank-1) % size, comm_tag=mytag_y, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag_y, shape=(4, 4), dtype=int)) + + z = y+y_plus # Find the partition - outputs = pt.make_dict_of_named_arrays({"out": y}) - distributed_parts = find_distributed_partition(outputs) + outputs = pt.make_dict_of_named_arrays({"out": z}) + distributed_parts = find_distributed_partition(comm, outputs) + distributed_parts, _ = number_distributed_tags( comm, distributed_parts, base_tag=42) prg_per_partition = generate_code_for_partition(distributed_parts) @@ -39,23 +47,31 @@ def main(): from pytato.visualization import show_dot_graph show_dot_graph(distributed_parts) - # Sanity check - from pytato.visualization import get_dot_graph_from_partition - get_dot_graph_from_partition(distributed_parts) + if 0: + # Sanity check + from pytato.visualization import get_dot_graph_from_partition + get_dot_graph_from_partition(distributed_parts) # Execute the distributed partition ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) + pt.verify_distributed_partition(comm, distributed_parts) + context = execute_distributed_partition(distributed_parts, prg_per_partition, queue, comm) final_res = context["out"].get(queue) comm.isend(x_in, dest=(rank-1) % size, tag=42) - ref_halo = comm.recv(source=(rank+1) % size, tag=42) + ref_x_plus = comm.recv(source=(rank+1) % size, tag=42) + + ref_y_in = x_in + ref_x_plus + + comm.isend(ref_y_in, dest=(rank-1) % size, tag=43) + ref_y_plus = comm.recv(source=(rank+1) % size, tag=43) - ref_res = x_in + ref_halo + ref_res = ref_y_in + ref_y_plus np.testing.assert_allclose(ref_res, final_res) diff --git a/examples/partition.py b/examples/partition.py deleted file mode 100644 index 422ac49..0000000 --- a/examples/partition.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python - -import pytato as pt -import pyopencl as cl -import numpy as np -from pytato.partition import (execute_partition, - generate_code_for_partition, find_partition) - -from pytato.transform import TopoSortMapper - -from dataclasses import dataclass - - -@dataclass(frozen=True, eq=True) -class MyPartitionId(): - num: int - - -def get_partition_id(topo_list, expr) -> MyPartitionId: - # Partition nodes into groups of two - res = MyPartitionId(topo_list.index(expr)//2) - return res - - -def main(): - x_in = np.random.randn(2, 2) - x = pt.make_data_wrapper(x_in) - y = pt.stack([x@x.T, 2*x, 42+x]) - y = y + 55 - - tm = TopoSortMapper() - tm(y) - - from functools import partial - pfunc = partial(get_partition_id, tm.topological_order) - - # Find the partitions - outputs = pt.DictOfNamedArrays({"out": y}) - partition = find_partition(outputs, pfunc) - - # Show the partitions - from pytato.visualization import get_dot_graph_from_partition - get_dot_graph_from_partition(partition) - - # Execute the partitions - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - - prg_per_partition = generate_code_for_partition(partition) - - context = execute_partition(partition, prg_per_partition, queue) - - final_res = [context[k] for k in outputs.keys()] - - # Execute the unpartitioned code for comparison - prg = pt.generate_loopy(y) - _, (out, ) = prg(queue) - - np.testing.assert_allclose([out], final_res) - - print("Partitioning test succeeded.") - - -if __name__ == "__main__": - main() diff --git a/pytato/__init__.py b/pytato/__init__.py index afb0865..731c9b0 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -101,16 +101,15 @@ from pytato.distributed.nodes import (make_distributed_send, make_distributed_re from pytato.distributed.partition import ( find_distributed_partition, DistributedGraphPart, DistributedGraphPartition) from pytato.distributed.tags import number_distributed_tags +from pytato.distributed.execute import ( + generate_code_for_partition, execute_distributed_partition) from pytato.distributed.verify import verify_distributed_partition -from pytato.distributed.execute import execute_distributed_partition from pytato.transform.lower_to_index_lambda import to_index_lambda from pytato.transform.remove_broadcasts_einsum import ( rewrite_einsums_with_no_broadcasts) from pytato.transform.metadata import unify_axes_tags -from pytato.partition import generate_code_for_partition - __all__ = ( "dtype", @@ -162,11 +161,10 @@ __all__ = ( "find_distributed_partition", "number_distributed_tags", + "generate_code_for_partition", "execute_distributed_partition", "verify_distributed_partition", - "generate_code_for_partition", - "to_index_lambda", "rewrite_einsums_with_no_broadcasts", diff --git a/pytato/distributed/__init__.py b/pytato/distributed/__init__.py index cee8854..04b368a 100644 --- a/pytato/distributed/__init__.py +++ b/pytato/distributed/__init__.py @@ -9,10 +9,6 @@ outputs as inputs". That sounds obvious, but in the distributed-memory case, this is harder to decide than it looks, since we do not have full knowledge of the computation graph. Edges go off to other nodes and then come back. -As a first step towards making this tractable, we currently strengthen the -requirement to create partition boundaries on every edge that goes between -nodes that are/are not a dependency of a receive or that feed/do not feed a send. - .. automodule:: pytato.distributed.nodes .. automodule:: pytato.distributed.partition .. automodule:: pytato.distributed.verify diff --git a/pytato/distributed/execute.py b/pytato/distributed/execute.py index bf464dc..79444cd 100644 --- a/pytato/distributed/execute.py +++ b/pytato/distributed/execute.py @@ -1,4 +1,6 @@ """ +Execution +--------- .. currentmodule:: pytato .. autofunction:: execute_distributed_partition @@ -30,9 +32,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING +from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING, Mapping +from pytato.array import make_dict_of_named_arrays from pytato.target import BoundProgram from pytato.scalar_expr import INT_CLASSES @@ -42,7 +45,7 @@ import numpy as np from pytato.distributed.nodes import ( DistributedRecv, DistributedSend) from pytato.distributed.partition import ( - DistributedGraphPartition, DistributedGraphPart) + DistributedGraphPartition, DistributedGraphPart, PartId) import logging logger = logging.getLogger(__name__) @@ -52,6 +55,28 @@ if TYPE_CHECKING: import mpi4py.MPI +# {{{ generate_code_for_partition + +def generate_code_for_partition(partition: DistributedGraphPartition) \ + -> Mapping[PartId, BoundProgram]: + """Return a mapping of partition identifiers to their + :class:`pytato.target.BoundProgram`.""" + from pytato import generate_loopy + part_id_to_prg = {} + + for part in sorted(partition.parts.values(), + key=lambda part_: sorted(part_.output_names)): + d = make_dict_of_named_arrays( + {var_name: partition.name_to_output[var_name] + for var_name in part.output_names + }) + part_id_to_prg[part.pid] = generate_loopy(d) + + return part_id_to_prg + +# }}} + + # {{{ distributed execute def _post_receive(mpi_communicator: mpi4py.MPI.Comm, @@ -88,11 +113,11 @@ def execute_distributed_partition( from mpi4py import MPI - if len(partition.parts) != 1: + if any(part.name_to_recv_node for part in partition.parts.values()): recv_names_tup, recv_requests_tup, recv_buffers_tup = zip(*[ (name,) + _post_receive(mpi_communicator, recv) for part in partition.parts.values() - for name, recv in part.input_name_to_recv_node.items()]) + for name, recv in part.name_to_recv_node.items()]) recv_names = list(recv_names_tup) recv_requests = list(recv_requests_tup) recv_buffers = list(recv_buffers_tup) @@ -100,7 +125,6 @@ def execute_distributed_partition( del recv_requests_tup del recv_buffers_tup else: - # Only a single partition, no recv requests exist recv_names = [] recv_requests = [] recv_buffers = [] @@ -146,13 +170,14 @@ def execute_distributed_partition( context.update(result_dict) - for name, send_node in part.output_name_to_send_node.items(): - # FIXME: pytato shouldn't depend on pyopencl - if isinstance(context[name], np.ndarray): - data = context[name] - else: - data = context[name].get(queue) - send_requests.append(_mpi_send(mpi_communicator, send_node, data)) + for name, send_nodes in part.name_to_send_nodes.items(): + for send_node in send_nodes: + # FIXME: pytato shouldn't depend on pyopencl + if isinstance(context[name], np.ndarray): + data = context[name] + else: + data = context[name].get(queue) + send_requests.append(_mpi_send(mpi_communicator, send_node, data)) pids_executed.add(part.pid) pids_to_execute.remove(part.pid) @@ -171,8 +196,8 @@ def execute_distributed_partition( buf = recv_buffers.pop(idx) # FIXME: pytato shouldn't depend on pyopencl - import pyopencl as cl - context[name] = cl.array.to_device(queue, buf, allocator=allocator) + import pyopencl.array as cl_array + context[name] = cl_array.to_device(queue, buf, allocator=allocator) recv_names_completed.add(name) # {{{ main loop @@ -182,7 +207,7 @@ def execute_distributed_partition( for pid in pids_to_execute # FIXME: Only O(n**2) altogether. Nobody is going to notice, right? if partition.parts[pid].needed_pids <= pids_executed - and (set(partition.parts[pid].input_name_to_recv_node) + and (set(partition.parts[pid].name_to_recv_node) <= recv_names_completed)} for pid in ready_pids: part = partition.parts[pid] diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index 856b3f8..3dc5721 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -1,12 +1,22 @@ """ +Nodes +----- +The following nodes represent communication in the DAG: + .. currentmodule:: pytato .. autoclass:: DistributedSend .. autoclass:: DistributedSendRefHolder .. autoclass:: DistributedRecv -.. autofunction:: make_distributed_send +These functions aid in creating communication nodes: + .. autofunction:: staple_distributed_send .. autofunction:: make_distributed_recv + +For completeness, individual (non-held/"stapled") :class:`DistributedSend` nodes +can be made via this function: + +.. autofunction:: make_distributed_send """ from __future__ import annotations @@ -53,6 +63,8 @@ CommTagType = Hashable class DistributedSend(Taggable): """Class representing a distributed send operation. + See :class:`DistributedSendRefHolder` for a way to ensure that nodes + of this type remain part of a DAG. .. attribute:: data diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5cfa069..a59bd0e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -1,10 +1,38 @@ """ +Partitioning +------------ + +Partitioning of graphs in :mod:`pytato` serves to enable +:ref:`distributed computation `, i.e. sending and receiving data +as part of graph evaluation. + +Partitioning of expression graphs is based on a few assumptions: + +- We must be able to execute parts in any dependency-respecting order. +- Parts are compiled at partitioning time, so what inputs they take from memory + vs. what they compute is decided at that time. +- No part may depend on its own outputs as inputs. + .. currentmodule:: pytato .. autoclass:: DistributedGraphPart .. autoclass:: DistributedGraphPartition .. autofunction:: find_distributed_partition + +.. currentmodule:: pytato.distributed.partition + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: T + + A type variable for :class:`~pytato.array.AbstractResultWithNamedArrays`. +.. autoclass:: CommunicationOpIdentifier +.. class:: CommunicationDepGraph + + An alias for + ``Mapping[CommunicationOpIdentifier, AbstractSet[CommunicationOpIdentifier]]``. """ from __future__ import annotations @@ -33,119 +61,238 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from functools import reduce +import collections from typing import ( - Tuple, Any, Mapping, FrozenSet, Set, Dict, cast, Iterable, Callable, List) -from functools import cached_property + Iterator, Iterable, Sequence, Any, Mapping, FrozenSet, Set, Dict, cast, + List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional) import attrs from immutables import Map +from pytools.graph import CycleError +from pytools import memoize_method + from pymbolic.mapper.optimize import optimize_mapper from pytools import UniqueNameGenerator -from pytools.tag import UniqueTag from pytato.scalar_expr import SCALAR_CLASSES -from pytato.array import (Array, - DictOfNamedArrays, Placeholder, make_placeholder, - NamedArray) -from pytato.transform import (ArrayOrNames, CopyMapper, Mapper, - CachedWalkMapper, CopyMapperWithExtraArgs, - CombineMapper, CopyMapperResultT) -from pytato.partition import GraphPart, GraphPartition, PartId, GraphPartitioner +from pytato.array import (Array, DictOfNamedArrays, Placeholder, make_placeholder) +from pytato.transform import (ArrayOrNames, CopyMapper, + CachedWalkMapper, + CombineMapper) from pytato.distributed.nodes import ( DistributedRecv, DistributedSend, DistributedSendRefHolder) +from pytato.distributed.nodes import CommTagType from pytato.analysis import DirectPredecessorsGetter +if TYPE_CHECKING: + import mpi4py.MPI + + +@attrs.define(frozen=True) +class CommunicationOpIdentifier: + """Identifies a communication operation (consisting of a pair of + a send and a receive). + + .. attribute:: src_rank + .. attribute:: dest_rank + .. attribute:: comm_tag + + .. note:: + + In :func:`~pytato.find_distributed_partition`, we use instances of this + type as though they identify sends or receives, i.e. just a single end + of the communication. Realize that this is only true given the + additional context of which rank is the local rank. + """ + src_rank: int + dest_rank: int + comm_tag: CommTagType + + +CommunicationDepGraph = Mapping[ + CommunicationOpIdentifier, AbstractSet[CommunicationOpIdentifier]] + + +_KeyT = TypeVar("_KeyT") +_ValueT = TypeVar("_ValueT") + + +# {{{ crude ordered set + + +class _OrderedSet(collections.abc.MutableSet[_ValueT]): + def __init__(self, items: Optional[Iterable[_ValueT]] = None): + # Could probably also use a valueless dictionary; not sure if it matters + self._items: Set[_ValueT] = set() + self._items_ordered: List[_ValueT] = [] + if items is not None: + for item in items: + self.add(item) + + def add(self, item: _ValueT) -> None: + if item not in self._items: + self._items.add(item) + self._items_ordered.append(item) + + def discard(self, item: _ValueT) -> None: + # Not currently needed + raise NotImplementedError + + def __len__(self) -> int: + return len(self._items) + + def __iter__(self) -> Iterator[_ValueT]: + return iter(self._items_ordered) + + def __contains__(self, item: Any) -> bool: + return item in self._items + + def __and__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: + result: _OrderedSet[_ValueT] = _OrderedSet() + for item in self._items_ordered: + if item in other: + result.add(item) + return result + + # Must be "Any" instead of "_ValueT", otherwise it violates Liskov substitution + # according to mypy. *shrug* + def __or__(self, other: AbstractSet[Any]) -> _OrderedSet[_ValueT]: + result: _OrderedSet[_ValueT] = _OrderedSet(self._items_ordered) + for item in other: + result.add(item) + return result + + def __sub__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: + result: _OrderedSet[_ValueT] = _OrderedSet() + for item in self._items_ordered: + if item not in other: + result.add(item) + return result + +# }}} + + +# {{{ distributed graph part + +PartId = Hashable -# {{{ distributed graph partition @attrs.define(frozen=True, slots=False) -class DistributedGraphPart(GraphPart): - """For one graph partition, record send/receive information for input/ +class DistributedGraphPart: + """For one graph part, record send/receive information for input/ output names. - .. attribute:: input_name_to_recv_node - .. attribute:: output_name_to_send_node - .. attribute:: distributed_sends + Names that occur as keys in :attr:`name_to_recv_node` and + :attr:`name_to_send_nodes` are usable as input names by other + parts, or in the result of the computation. + + - Names specified in :attr:`name_to_recv_node` *must not* occur in + :attr:`output_names`. + - Names specified in :attr:`name_to_send_nodes` *must* occur in + :attr:`output_names`. + + .. attribute:: pid + + An identifier for this part of the graph. + + .. attribute:: needed_pids + + The IDs of parts that are required to be evaluated before this + part can be evaluated. + + .. attribute:: user_input_names + + A :class:`frozenset` of names representing input to the computational + graph, i.e. which were *not* introduced by partitioning. + + .. attribute:: partition_input_names + + A :class:`frozenset` of names of placeholders the part requires as + input from other parts in the partition. + + .. attribute:: output_names + + Names of placeholders this part provides as output. + + .. attribute:: name_to_recv_node + .. attribute:: name_to_send_nodes + + .. automethod:: all_input_names """ - input_name_to_recv_node: Dict[str, DistributedRecv] - output_name_to_send_node: Dict[str, DistributedSend] - distributed_sends: List[DistributedSend] + pid: PartId + needed_pids: FrozenSet[PartId] + user_input_names: FrozenSet[str] + partition_input_names: FrozenSet[str] + output_names: FrozenSet[str] + name_to_recv_node: Mapping[str, DistributedRecv] + name_to_send_nodes: Mapping[str, Sequence[DistributedSend]] + + @memoize_method + def all_input_names(self) -> FrozenSet[str]: + return self.user_input_names | self. partition_input_names + +# }}} + + +# {{{ distributed graph partition @attrs.define(frozen=True, slots=False) -class DistributedGraphPartition(GraphPartition): - """Store information about distributed graph partitions. This - has the same attributes as :class:`~pytato.partition.GraphPartition`, - however :attr:`~pytato.partition.GraphPartition.parts` now maps to - instances of :class:`DistributedGraphPart`. +class DistributedGraphPartition: + """ + .. attribute:: parts + + Mapping from part IDs to instances of :class:`DistributedGraphPart`. + + .. attribute:: name_to_output + + Mapping of placeholder names to the respective :class:`pytato.array.Array` + they represent. """ - parts: Dict[PartId, DistributedGraphPart] + parts: Mapping[PartId, DistributedGraphPart] + name_to_output: Mapping[str, Array] # }}} -# {{{ _partition_to_distributed_partition +# {{{ _DistributedInputReplacer -def _map_distributed_graph_partition_nodes( - map_array: Callable[[Array], Array], - map_send: Callable[[DistributedSend], DistributedSend], - gp: DistributedGraphPartition) -> DistributedGraphPartition: - """Return a new copy of *gp* with all :class:`~pytato.Array` instances - mapped by *map_array* and all :class:`DistributedSend` instances mapped - by *map_send*. - """ - from attrs import evolve as replace - - return replace( - gp, - var_name_to_result={name: map_array(ary) - for name, ary in gp.var_name_to_result.items()}, - parts={ - pid: replace(part, - input_name_to_recv_node={ - in_name: cast(DistributedRecv, map_array(recv)) - for in_name, recv in part.input_name_to_recv_node.items()}, - output_name_to_send_node={ - out_name: map_send(send) - for out_name, send in part.output_name_to_send_node.items()}, - distributed_sends=[ - map_send(send) for send in part.distributed_sends] - ) - for pid, part in gp.parts.items() - }) - - -class _DistributedCommReplacer(CopyMapper): - """Mapper to process a DAG for realization of :class:`DistributedSend` - and :class:`DistributedRecv` outside of normal code generation. - - - Replaces :class:`DistributedRecv` with :class`~pytato.Placeholder` - so that received data can be externally supplied, making a note - in :attr:`input_name_to_recv_node`. - - - Makes note of data to be sent from :class:`DistributedSend` nodes - in :attr:`output_name_to_send_node`. +class _DistributedInputReplacer(CopyMapper): + """Replaces part inputs with :class:`~pytato.array.Placeholder` + instances for their assigned names. Also gathers names for + user-supplied inputs needed by the part """ - def __init__(self, dist_name_generator: UniqueNameGenerator) -> None: + def __init__(self, + recvd_ary_to_name: Mapping[Array, str], + sptpo_ary_to_name: Mapping[Array, str], + name_to_output: Mapping[str, Array], + ) -> None: super().__init__() - self.name_generator = dist_name_generator + self.recvd_ary_to_name = recvd_ary_to_name + self.sptpo_ary_to_name = sptpo_ary_to_name + self.name_to_output = name_to_output + self.output_arrays = frozenset(name_to_output.values()) - self.input_name_to_recv_node: Dict[str, DistributedRecv] = {} - self.output_name_to_send_node: Dict[str, DistributedSend] = {} + self.user_input_names: Set[str] = set() + self.partition_input_name_to_placeholder: Dict[str, Placeholder] = {} + + def map_placeholder(self, expr: Placeholder) -> Placeholder: + self.user_input_names.add(expr.name) + return expr def map_distributed_recv(self, expr: DistributedRecv) -> Placeholder: - new_name = self.name_generator() - self.input_name_to_recv_node[new_name] = expr - return make_placeholder(new_name, self.rec_idx_or_size_tuple(expr.shape), - expr.dtype, tags=expr.tags, axes=expr.axes) + name = self.recvd_ary_to_name[expr] + return self._get_placeholder_for(name, expr) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: - raise ValueError("DistributedSendRefHolder should not occur in partitioned " - "graphs") + result = self.rec(expr.passthrough_data) + assert isinstance(result, Array) + return result # Note: map_distributed_send() is not called like other mapped methods in a # DAG traversal, since a DistributedSend is not an Array and has no @@ -154,627 +301,683 @@ class _DistributedCommReplacer(CopyMapper): # are no more references to the DistributedSends from within the DAG. This # method must therefore be called explicitly. def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: + new_data = self.rec(expr.data) + assert isinstance(new_data, Array) new_send = DistributedSend( - data=self.rec(expr.data), + data=new_data, dest_rank=expr.dest_rank, comm_tag=expr.comm_tag, tags=expr.tags) + return new_send - new_name = self.name_generator() - self.output_name_to_send_node[new_name] = new_send + # type ignore because no args, kwargs + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override] + assert isinstance(expr, Array) - return new_send + key = self.get_cache_key(expr) + try: + return self._cache[key] + except KeyError: + pass + + # If the array is an output from the current part, it would + # be counterproductive to turn it into a placeholder: we're + # the ones who are supposed to compute it! + if expr not in self.output_arrays: + + name = self.recvd_ary_to_name.get(expr) + if name is not None: + return self._get_placeholder_for(name, expr) + + name = self.sptpo_ary_to_name.get(expr) + if name is not None: + return self._get_placeholder_for(name, expr) + return cast(ArrayOrNames, super().rec(expr)) -def _partition_to_distributed_partition(partition: GraphPartition, - pid_to_distributed_sends: Dict[PartId, List[DistributedSend]]) -> \ - DistributedGraphPartition: - var_name_to_result = {} + def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: + placeholder = self.partition_input_name_to_placeholder.get(name) + if placeholder is None: + placeholder = make_placeholder( + name, expr.shape, expr.dtype, expr.tags, + expr.axes) + self.partition_input_name_to_placeholder[name] = placeholder + return placeholder + +# }}} + + +@attrs.define(frozen=True) +class _PartCommIDs: + """A *part*, unlike a *batch*, begins with receives and ends with sends. + """ + recv_ids: FrozenSet[CommunicationOpIdentifier] + send_ids: FrozenSet[CommunicationOpIdentifier] + + +# {{{ _make_distributed_partition + +def _make_distributed_partition( + name_to_output_per_part: Sequence[Mapping[str, Array]], + part_comm_ids: Sequence[_PartCommIDs], + recvd_ary_to_name: Mapping[Array, str], + sent_ary_to_name: Mapping[Array, str], + sptpo_ary_to_name: Mapping[Array, str], + local_recv_id_to_recv_node: Dict[CommunicationOpIdentifier, DistributedRecv], + local_send_id_to_send_node: Dict[CommunicationOpIdentifier, DistributedSend], + ) -> DistributedGraphPartition: + name_to_output = {} parts: Dict[PartId, DistributedGraphPart] = {} - dist_name_generator = UniqueNameGenerator(forced_prefix="_pt_dist_") - - for part in sorted(partition.parts.values(), - key=lambda k: sorted(k.output_names)): - comm_replacer = _DistributedCommReplacer(dist_name_generator) - part_results = { - var_name: comm_replacer(partition.var_name_to_result[var_name]) - for var_name in sorted(part.output_names)} - - dist_sends = [ - comm_replacer.map_distributed_send(send) - for send in pid_to_distributed_sends.get(part.pid, [])] - - part_results.update({ - name: send_node.data - for name, send_node in - comm_replacer.output_name_to_send_node.items()}) - - parts[part.pid] = DistributedGraphPart( - pid=part.pid, - needed_pids=part.needed_pids, - user_input_names=part.user_input_names, - partition_input_names=(part.partition_input_names - | frozenset(comm_replacer.input_name_to_recv_node)), - output_names=(part.output_names - | frozenset(comm_replacer.output_name_to_send_node)), - distributed_sends=dist_sends, - - input_name_to_recv_node=comm_replacer.input_name_to_recv_node, - output_name_to_send_node=comm_replacer.output_name_to_send_node) - - for name, val in part_results.items(): - assert name not in var_name_to_result - var_name_to_result[name] = val + for part_id, name_to_ouput in enumerate(name_to_output_per_part): + comm_replacer = _DistributedInputReplacer( + recvd_ary_to_name, sptpo_ary_to_name, name_to_ouput) + + for name, val in name_to_ouput.items(): + assert name not in name_to_output + name_to_output[name] = comm_replacer(val) + + comm_ids = part_comm_ids[part_id] + + name_to_send_nodes: Dict[str, List[DistributedSend]] = {} + for send_id in comm_ids.send_ids: + send_node = local_send_id_to_send_node[send_id] + name = sent_ary_to_name[send_node.data] + name_to_send_nodes.setdefault(name, []).append( + comm_replacer.map_distributed_send(send_node)) + + parts[part_id] = DistributedGraphPart( + pid=part_id, + needed_pids=frozenset({part_id - 1} if part_id else {}), + user_input_names=frozenset(comm_replacer.user_input_names), + partition_input_names=frozenset( + comm_replacer.partition_input_name_to_placeholder.keys()), + output_names=frozenset(name_to_ouput.keys()), + name_to_recv_node=Map({ + recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]: + local_recv_id_to_recv_node[recv_id] + for recv_id in comm_ids.recv_ids}), + name_to_send_nodes=Map(name_to_send_nodes)) result = DistributedGraphPartition( parts=parts, - var_name_to_result=var_name_to_result, - toposorted_part_ids=partition.toposorted_part_ids) - - if __debug__: - # Check disjointness again since we replaced a few nodes. - from pytato.partition import _check_partition_disjointness - _check_partition_disjointness(result) + name_to_output=name_to_output, + ) return result # }}} -# {{{ helpers for find_distributed_partition +# {{{ _LocalSendRecvDepGatherer -class _DistributedGraphPartitioner(GraphPartitioner): +def _send_to_comm_id( + local_rank: int, send: DistributedSend) -> CommunicationOpIdentifier: + if local_rank == send.dest_rank: + raise NotImplementedError("Self-sends are not currently allowed. " + f"(tag: '{send.comm_tag}')") - def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: - super().__init__(get_part_id) - self.pid_to_dist_sends: Dict[PartId, List[DistributedSend]] = {} + return CommunicationOpIdentifier( + src_rank=local_rank, + dest_rank=send.dest_rank, + comm_tag=send.comm_tag) - def map_distributed_send_ref_holder( - self, expr: DistributedSendRefHolder, *args: Any) -> Any: - send_part_id = self.get_part_id(expr.send.data) - rec_send_data = self.rec(expr.send.data) - assert isinstance(rec_send_data, Array) - - self.pid_to_dist_sends.setdefault(send_part_id, []).append( - DistributedSend( - data=rec_send_data, - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag, - tags=expr.send.tags)) - - return self.rec(expr.passthrough_data) - def make_partition(self, outputs: DictOfNamedArrays) \ - -> DistributedGraphPartition: +def _recv_to_comm_id( + local_rank: int, recv: DistributedRecv) -> CommunicationOpIdentifier: + if local_rank == recv.src_rank: + raise NotImplementedError("Self-receives are not currently allowed. " + f"(tag: '{recv.comm_tag}')") - partition = super().make_partition(outputs) - return _partition_to_distributed_partition(partition, self.pid_to_dist_sends) + return CommunicationOpIdentifier( + src_rank=recv.src_rank, + dest_rank=local_rank, + comm_tag=recv.comm_tag) -class _MandatoryPartitionOutputsCollector(CombineMapper[FrozenSet[Array]]): - """ - Collects all nodes that, after partitioning, are necessarily outputs - of the partition to which they belong. - """ - def __init__(self) -> None: +class _LocalSendRecvDepGatherer( + CombineMapper[FrozenSet[CommunicationOpIdentifier]]): + def __init__(self, local_rank: int) -> None: super().__init__() - self.partition_outputs: Set[Array] = set() + self.local_comm_ids_to_needed_comm_ids: \ + Dict[CommunicationOpIdentifier, + FrozenSet[CommunicationOpIdentifier]] = {} - def combine(self, *args: FrozenSet[Array]) -> FrozenSet[Array]: - from functools import reduce + self.local_recv_id_to_recv_node: \ + Dict[CommunicationOpIdentifier, DistributedRecv] = {} + self.local_send_id_to_send_node: \ + Dict[CommunicationOpIdentifier, DistributedSend] = {} + + self.local_rank = local_rank + + def combine( + self, *args: FrozenSet[CommunicationOpIdentifier] + ) -> FrozenSet[CommunicationOpIdentifier]: return reduce(frozenset.union, args, frozenset()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> FrozenSet[Array]: - return self.combine(frozenset([expr.send.data]), - super().map_distributed_send_ref_holder(expr)) + ) -> FrozenSet[CommunicationOpIdentifier]: + send_id = _send_to_comm_id(self.local_rank, expr.send) + + if send_id in self.local_send_id_to_send_node: + from pytato.distributed.verify import DuplicateSendError + raise DuplicateSendError(f"Multiple sends found for '{send_id}'") + + self.local_comm_ids_to_needed_comm_ids[send_id] = \ + self.rec(expr.send.data) + + self.local_send_id_to_send_node[send_id] = expr.send + + return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> FrozenSet[Array]: + def _map_input_base(self, expr: Array) -> FrozenSet[CommunicationOpIdentifier]: return frozenset() map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - map_distributed_recv = _map_input_base + def map_distributed_recv( + self, expr: DistributedRecv + ) -> FrozenSet[CommunicationOpIdentifier]: + recv_id = _recv_to_comm_id(self.local_rank, expr) -@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) -class _MaterializedArrayCollector(CachedWalkMapper): - """ - Collects all nodes that have to be materialized during code-generation. - """ - def __init__(self) -> None: - super().__init__() - self.materialized_arrays: Set[Array] = set() + if recv_id in self.local_recv_id_to_recv_node: + from pytato.distributed.verify import DuplicateRecvError + raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] - return id(expr) + self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def post_visit(self, expr: Any) -> None: # type: ignore[override] - from pytato.tags import ImplStored - from pytato.loopy import LoopyCallResult + self.local_recv_id_to_recv_node[recv_id] = expr - if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.add(expr) + return frozenset({recv_id}) - if isinstance(expr, LoopyCallResult): - self.materialized_arrays.add(expr) - from pytato.loopy import LoopyCall - assert isinstance(expr._container, LoopyCall) - for _, subexpr in sorted(expr._container.bindings.items()): - if isinstance(subexpr, Array): - self.materialized_arrays.add(subexpr) - else: - assert isinstance(subexpr, SCALAR_CLASSES) +# }}} - if isinstance(expr, DictOfNamedArrays): - for _, subexpr in sorted(expr._data.items()): - assert isinstance(subexpr, Array) - self.materialized_arrays.add(subexpr) +# {{{ _schedule_comm_batches -class _DominantMaterializedPredecessorsCollector(Mapper): +def _schedule_comm_batches( + comm_ids_to_needed_comm_ids: CommunicationDepGraph + ) -> Sequence[AbstractSet[CommunicationOpIdentifier]]: + """For each :class:`CommunicationOpIdentifier`, determine the + 'round'/'batch' during which it will be performed. A 'batch' + of communication consists of sends and receives. Computation + occurs between batches. (So, from the perspective of the + :class:`DistributedGraphPartition`, communication batches + sit *between* parts.) """ - A Mapper whose mapper method for a node returns the materialized predecessors - just after the point the node is evaluated. - """ - def __init__(self, materialized_arrays: FrozenSet[Array]) -> None: - super().__init__() - self.materialized_arrays = materialized_arrays - self.cache: Dict[ArrayOrNames, FrozenSet[Array]] = {} + # FIXME: I'm an O(n^2) algorithm. - def _combine(self, values: Iterable[Array]) -> FrozenSet[Array]: - from functools import reduce - return reduce(frozenset.union, - (self.rec(v) for v in values), - frozenset()) + comm_batches: List[AbstractSet[CommunicationOpIdentifier]] = [] - # type-ignore reason: return type not compatible with Mapper.rec's type - def rec(self, expr: ArrayOrNames) -> FrozenSet[Array]: # type: ignore[override] - try: - return self.cache[expr] - except KeyError: - result: FrozenSet[Array] = super().rec(expr) - self.cache[expr] = result - return result + scheduled_comm_ids: Set[CommunicationOpIdentifier] = set() + comms_to_schedule = set(comm_ids_to_needed_comm_ids) - @cached_property - def direct_preds_getter(self) -> DirectPredecessorsGetter: - return DirectPredecessorsGetter() + all_comm_ids = frozenset(comm_ids_to_needed_comm_ids) - def _map_generic_node(self, expr: Array) -> FrozenSet[Array]: - direct_preds = self.direct_preds_getter(expr) + # FIXME In order for this to work, comm tags must be unique + while len(scheduled_comm_ids) < len(all_comm_ids): + comm_ids_this_batch = { + comm_id for comm_id in comms_to_schedule + if comm_ids_to_needed_comm_ids[comm_id] <= scheduled_comm_ids} - if expr in self.materialized_arrays: - return frozenset([expr]) - else: - return self._combine(direct_preds) - - map_placeholder = _map_generic_node - map_data_wrapper = _map_generic_node - map_size_param = _map_generic_node - - map_index_lambda = _map_generic_node - map_stack = _map_generic_node - map_concatenate = _map_generic_node - map_roll = _map_generic_node - map_axis_permutation = _map_generic_node - map_basic_index = _map_generic_node - map_contiguous_advanced_index = _map_generic_node - map_non_contiguous_advanced_index = _map_generic_node - map_reshape = _map_generic_node - map_einsum = _map_generic_node - map_distributed_recv = _map_generic_node - - def map_named_array(self, expr: NamedArray) -> FrozenSet[Array]: - raise NotImplementedError("only LoopyCallResult named array" - " supported for now.") - - def map_dict_of_named_arrays(self, expr: DictOfNamedArrays - ) -> FrozenSet[Array]: - raise NotImplementedError("Dict of named arrays not (yet) implemented") - - def map_loopy_call_result(self, expr: NamedArray) -> FrozenSet[Array]: - # ``loopy call result` is always materialized. However, make sure to - # traverse its arguments. - assert expr in self.materialized_arrays - return self._map_generic_node(expr) + if not comm_ids_this_batch: + raise CycleError("cycle detected in communication graph") + + scheduled_comm_ids.update(comm_ids_this_batch) + comms_to_schedule = comms_to_schedule - comm_ids_this_batch + + comm_batches.append(comm_ids_this_batch) + + return comm_batches + +# }}} - def map_distributed_send_ref_holder(self, - expr: DistributedSendRefHolder - ) -> FrozenSet[Array]: - return self.rec(expr.passthrough_data) +# {{{ _MaterializedArrayCollector @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) -class _DominantMaterializedPredecessorsRecorder(CachedWalkMapper): +class _MaterializedArrayCollector(CachedWalkMapper): """ - For each node in an expression graph, this mapper records the dominant - predecessors of each node of an expression graph into - :attr:`array_to_mat_preds`. + Collects all nodes that have to be materialized during code-generation. """ - def __init__(self, mat_preds_getter: Callable[[Array], FrozenSet[Array]] - ) -> None: + def __init__(self) -> None: super().__init__() - self.mat_preds_getter = mat_preds_getter - self.array_to_mat_preds: Dict[Array, FrozenSet[Array]] = {} + self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() # type-ignore-reason: dropped the extra `*args, **kwargs`. def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] return id(expr) - @cached_property - def direct_preds_getter(self) -> DirectPredecessorsGetter: - return DirectPredecessorsGetter() - # type-ignore-reason: dropped the extra `*args, **kwargs`. def post_visit(self, expr: Any) -> None: # type: ignore[override] - from functools import reduce - if isinstance(expr, Array): - self.array_to_mat_preds[expr] = reduce( - frozenset.union, - (self.mat_preds_getter(pred) - for pred in self.direct_preds_getter(expr)), - frozenset()) - - -def _linearly_schedule_batches( - predecessors: Map[Array, FrozenSet[Array]]) -> Map[Array, int]: - """ - Used by :func:`find_distributed_partition`. Based on the dependencies in - *predecessors*, each node is assigned a time such that evaluating the array - at that point in time would not violate dependencies. This "time" or "batch - number" is then used as a partition ID. - """ - from functools import reduce - current_time = 0 - ary_to_time = {} - scheduled_nodes: Set[Array] = set() - all_nodes = frozenset(predecessors) + from pytato.tags import ImplStored + from pytato.loopy import LoopyCallResult - # assert that the keys contain all the nodes - assert reduce(frozenset.union, - predecessors.values(), - cast(FrozenSet[Array], frozenset())) <= frozenset(predecessors) + if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): + self.materialized_arrays.add(expr) - while len(scheduled_nodes) < len(all_nodes): - # {{{ eagerly schedule nodes whose predecessors have been scheduled + if isinstance(expr, LoopyCallResult): + self.materialized_arrays.add(expr) + from pytato.loopy import LoopyCall + assert isinstance(expr._container, LoopyCall) + for _, subexpr in sorted(expr._container.bindings.items()): + if isinstance(subexpr, Array): + self.materialized_arrays.add(subexpr) + else: + assert isinstance(subexpr, SCALAR_CLASSES) - nodes_to_schedule = {node - for node, preds in predecessors.items() - if ((node not in scheduled_nodes) - and (preds <= scheduled_nodes))} - for node in nodes_to_schedule: - assert node not in ary_to_time - ary_to_time[node] = current_time +# }}} - scheduled_nodes.update(nodes_to_schedule) - current_time += 1 +# {{{ _set_dict_union_mpi - # }}} +def _set_dict_union_mpi( + dict_a: Mapping[_KeyT, FrozenSet[_ValueT]], + dict_b: Mapping[_KeyT, FrozenSet[_ValueT]], + mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, FrozenSet[_ValueT]]: + assert mpi_data_type is None + result = dict(dict_a) + for key, values in dict_b.items(): + result[key] = result.get(key, frozenset()) | values + return result - assert set(ary_to_time.values()) == set(range(current_time)) - return Map(ary_to_time) +# }}} -def _assign_materialized_arrays_to_part_id( - materialized_arrays: FrozenSet[Array], - array_to_output_deps: Mapping[Array, FrozenSet[Array]], - outputs_to_part_id: Mapping[Array, int] -) -> Map[Array, int]: - """ - Returns a mapping from a materialized array to the part's ID where all the - inputs of the array expression are available. +# {{{ find_distributed_partition - Invoked as an intermediate step in :func:`find_distributed_partition`. +def find_distributed_partition( + mpi_communicator: mpi4py.MPI.Comm, + outputs: DictOfNamedArrays + ) -> DistributedGraphPartition: + r""" + Compute a :class:DistributedGraphPartition` (for use with + :func:`execute_distributed_partition`) that evaluates the + same result as *outputs*, such that: + + - communication only happens at the beginning and end of + each :class:`DistributedGraphPart`, and + - the partition introduces no circular dependencies between parts, + mediated by either local data flow or off-rank communication. + + .. warning:: + + This is an MPI-collective operation. + + The following sections describe the (non-binding, as far as documentation + is concerned) algorithm behind the partitioner. + + .. rubric:: Preliminaries + + We identify a communication operation (consisting of a pair of a send + and a receive) by a + :class:`~pytato.distributed.partition.CommunicationOpIdentifier`. We keep + graphs of these in + :class:`~pytato.distributed.partition.CommunicationDepGraph`. + + If ``graph`` is a + :class:`~pytato.distributed.partition.CommunicationDepGraph`, then ``b in + graph[a]`` means that, in order to initiate the communication operation + identified by :class:`~pytato.distributed.partition.CommunicationOpIdentifier` + ``a``, the communication operation identified by + :class:`~pytato.distributed.partition.CommunicationOpIdentifier` ``b`` must + be completed. + I.e. the nodes are "communication operations", i.e. pairs of + send/receive. Edges represent (rank-local) data flow between them. + + .. rubric:: Step 1: Build a global graph of data flow between communication + operations + + As a first step, each rank receives a copy of global + :class:`~pytato.distributed.partition.CommunicationDepGraph`, as described + above. This becomes ``comm_ids_to_needed_comm_ids``. + + .. rubric:: Step 2: Obtain a "schedule" of "communication batches" + + On rank 0, compute and broadcast a topological order of + ``comm_ids_to_needed_comm_ids``. The result of this is + ``comm_batches``, a sequence of sets of + :class:`~pytato.distributed.partition.CommunicationOpIdentifier` + instances, identifying sets of communication operations expected + to complete *between* parts of the computation. (I.e. computation + will occur before the first communication batch, then between the + first and second, and so on.) .. note:: - In this heuristic we compute the materialized array as soon as its - inputs are available. In some cases it might be worth exploring - schedules where the evaluation of an array is delayed until one of - its users demand it. - """ - - materialized_array_to_part_id: Dict[Array, int] = {} + An important restriction of this scheme is that a linear order + of communication batches is obtained, meaning that, typically, + no overlap of computation and communication occurs. - for ary in materialized_arrays: - materialized_array_to_part_id[ary] = max( - (outputs_to_part_id[dep] - for dep in array_to_output_deps[ary]), - default=-1) + 1 + .. rubric:: Step 3: Create rank-local part descriptors - return Map(materialized_array_to_part_id) + On each rank, we next rewrite the communication batches into computation + parts, each identified by a ``_PartCommIDs`` structure, which + gathers receives that need to complete *before* the computation on a part + can begin and sends that can begin once computation on a part + is complete. + .. rubric:: Step 4: Assign materialized arrays to parts -def _get_array_to_dominant_materialized_deps( - outputs: DictOfNamedArrays, - materialized_arrays: FrozenSet[Array]) -> Map[Array, FrozenSet[Array]]: - """ - Returns a mapping from each node in the DAG *outputs* to a :class:`frozenset` - of its dominant materialized predecessors. - """ + "Stored" arrays are those whose value will be computed and stored + in memory. This includes the following: - dominant_materialized_deps = _DominantMaterializedPredecessorsCollector( - materialized_arrays) - dominant_materialized_deps_recorder = ( - _DominantMaterializedPredecessorsRecorder(dominant_materialized_deps)) - dominant_materialized_deps_recorder(outputs) - return Map(dominant_materialized_deps_recorder.array_to_mat_preds) + - Arrays tagged :class:`~pytato.tags.ImplStored` by prior processing of the DAG, + - arrays being sent (because we need to hand a buffer to MPI), + - arrays being received (because MPI puts the received data + in memory) + - Overall outputs of the computation. + By contrast, the code below uses the word "materialized" only for arrays of + the first type (tagged :class:`~pytato.tags.ImplStored`), so that 'stored' is a + superset of 'materialized'. -def _get_materialized_arrays_promoted_to_partition_outputs( - ary_to_dominant_stored_preds: Mapping[Array, FrozenSet[Array]], - stored_ary_to_part_id: Mapping[Array, int], - materialized_arrays: FrozenSet[Array] -) -> FrozenSet[Array]: - """ - Returns a :class:`frozenset` of materialized arrays that are used by - multiple partitions. Materialized arrays that are used by multiple - partitions are special in that they *must* be promoted as the outputs of a - partition. + In addition, data computed by one *part* (in the above sense) of the + computation and used by another must be in memory. Evaluating and storing + temporary arrays is expensive, and so we try to minimize the number of + times that that this occurs as part of the partitioning. This is done by + relying on already-stored arrays as much as possible and recomputing any + intermediate results needed in, say, an originating and a consuming part. - Invoked as an intermediate step in :func:`find_distributed_partition`. + We begin this process by assigning each materialized + array to a part in which it is computed, based on the part in which + data depending on such arrays is sent. This choice implies that these + computations occur as late as possible. - :arg ary_to_dominant_stored_preds: A mapping from array to the dominant - stored predecessors. A stored array can be either a mandatory partition - output or a materialized array as indicated by the user. - """ - materialized_ary_to_part_id_users: Dict[Array, Set[int]] = {} + .. rubric:: Step 5: Promote stored arrays to part outputs if needed - for ary in stored_ary_to_part_id: - stored_preds = ary_to_dominant_stored_preds[ary] - for pred in stored_preds: - if pred in materialized_arrays: - (materialized_ary_to_part_id_users - .setdefault(pred, set()) - .add(stored_ary_to_part_id[ary])) + In :class:`DistributedGraphPart`, our description of the partitioned + computation, each part can declare named 'outputs' that can be used + by subsequent parts. Stored arrays are promoted to part outputs + if they have users in parts other than the one in which they + are computed. - return frozenset({ary - for ary, users in materialized_ary_to_part_id_users.items() - if users != {stored_ary_to_part_id[ary]}}) + .. rubric:: Step 6:: Rewrite the DAG into its parts + In the final step, we traverse the DAG to apply the following changes: -@attrs.define(frozen=True, eq=True, repr=True) -class PartIDTag(UniqueTag): + - Replace :class:`DistributedRecv` nodes with placeholders for names + assigned in :attr:`DistributedGraphPart.name_to_recv_node`. + - Replace references to out-of-part stored arrays with + :class:`~pytato.array.Placeholder` instances. + - Gather sent arrays into + assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ - A tag applicable to a :class:`pytato.Array` recording to which part the - array belongs. - """ - part_id: int + from pytato.transform import SubsetDependencyMapper + import mpi4py.MPI as MPI -class _PartIDTagAssigner(CopyMapperWithExtraArgs): - """ - Used by :func:`find_distributed_partition` to assign each array - node a :class:`PartIDTag`. - """ - def __init__(self, - stored_array_to_part_id: Mapping[Array, int], - partition_outputs: FrozenSet[Array]) -> None: - self.stored_array_to_part_id = stored_array_to_part_id - self.partition_outputs = partition_outputs - - # type-ignore reason: incompatible attribute type wrt base. - self._cache: Dict[Tuple[ArrayOrNames, int], - ArrayOrNames] = {} # type: ignore[assignment] - - # type-ignore-reason: incompatible with super class - def get_cache_key(self, # type: ignore[override] - expr: ArrayOrNames, - user_part_id: int - ) -> Tuple[ArrayOrNames, int]: - - return (expr, user_part_id) - - # type-ignore-reason: incompatible with super class - def rec(self, # type: ignore[override] - expr: CopyMapperResultT, - user_part_id: int) -> CopyMapperResultT: - key = self.get_cache_key(expr, user_part_id) - try: - # type-ignore-reason: parametric dicts are not a thing in typing module - return self._cache[key] # type: ignore[return-value] - except KeyError: - if isinstance(expr, Array): - if expr in self.stored_array_to_part_id: - assert ((self.stored_array_to_part_id[expr] - == user_part_id) - or expr in self.partition_outputs) - # at stored array the part id changes - user_part_id = self.stored_array_to_part_id[expr] + local_rank = mpi_communicator.rank - expr = expr.tagged(PartIDTag(user_part_id)) + # {{{ find comm_ids_to_needed_comm_ids - result = super().rec(expr, user_part_id) - self._cache[key] = result - return result + lsrdg = _LocalSendRecvDepGatherer(local_rank=local_rank) + lsrdg(outputs) + local_comm_ids_to_needed_comm_ids = \ + lsrdg.local_comm_ids_to_needed_comm_ids - # type-ignore-reason: incompatible with super class - def __call__(self, # type: ignore[override] - expr: CopyMapperResultT, - user_part_id: int) -> CopyMapperResultT: - return self.rec(expr, user_part_id) + set_dict_union_mpi_op = MPI.Op.Create( + # type ignore reason: mpi4py misdeclares op functions as returning + # None. + _set_dict_union_mpi, # type: ignore[arg-type] + commute=True) + try: + comm_ids_to_needed_comm_ids = mpi_communicator.allreduce( + local_comm_ids_to_needed_comm_ids, set_dict_union_mpi_op) + finally: + set_dict_union_mpi_op.Free() + # }}} -def _remove_part_id_tag(ary: ArrayOrNames) -> Array: - assert isinstance(ary, Array) + # {{{ make batches out of comm_ids_to_needed_comm_ids - # Spurious assignment because of - # https://github.com/python/mypy/issues/12626 - result: Array = ary.without_tags(ary.tags_of_type(PartIDTag)) - return result + if mpi_communicator.rank == 0: + # The comm_batches correspond one-to-one to DistributedGraphParts + # in the output. + try: + comm_batches = _schedule_comm_batches(comm_ids_to_needed_comm_ids) + except Exception as exc: + mpi_communicator.bcast(exc) + raise + else: + mpi_communicator.bcast(comm_batches) + else: + comm_batches_or_exc = mpi_communicator.bcast(None) + if isinstance(comm_batches_or_exc, Exception): + raise comm_batches_or_exc -# }}} + comm_batches = cast( + Sequence[AbstractSet[CommunicationOpIdentifier]], + comm_batches_or_exc) + # }}} -# {{{ find_distributed_partition + # {{{ create (local) parts out of batch ids + + part_comm_ids: List[_PartCommIDs] = [] + + if comm_batches: + recv_ids: FrozenSet[CommunicationOpIdentifier] = frozenset() + for batch in comm_batches: + send_ids = frozenset( + comm_id for comm_id in batch + if comm_id.src_rank == local_rank) + if recv_ids or send_ids: + part_comm_ids.append( + _PartCommIDs( + recv_ids=recv_ids, + send_ids=send_ids)) + # These go into the next part + recv_ids = frozenset( + comm_id for comm_id in batch + if comm_id.dest_rank == local_rank) + if recv_ids: + part_comm_ids.append( + _PartCommIDs( + recv_ids=recv_ids, + send_ids=frozenset())) + else: + part_comm_ids.append( + _PartCommIDs( + recv_ids=frozenset(), + send_ids=frozenset())) + + nparts = len(part_comm_ids) -def find_distributed_partition(outputs: DictOfNamedArrays - ) -> DistributedGraphPartition: - """ - Partitions *outputs* into parts. Between two parts communication - statements (sends/receives) are scheduled. + if __debug__: + from pytato.distributed.verify import MissingRecvError, MissingSendError - .. note:: + for part in part_comm_ids: + for recv_id in part.recv_ids: + if recv_id not in lsrdg.local_recv_id_to_recv_node: + raise MissingRecvError(f"no receive for '{recv_id}'") + for send_id in part.send_ids: + if send_id not in lsrdg.local_send_id_to_send_node: + raise MissingSendError(f"no send for '{send_id}'") - The partitioning of a DAG generally does not have a unique solution. - The heuristic employed by this partitioner is as follows: - - 1. The data contained in :class:`~pytato.DistributedSend` are marked as - *mandatory part outputs*. - 2. Based on the dependencies in *outputs*, a DAG is constructed with - only the mandatory part outputs as the nodes. - 3. Using a topological sort the mandatory part outputs are assigned a - "time" (an integer) such that evaluating these outputs at that time - would preserve dependencies. We maximize the number of part outputs - scheduled at a each "time". This requirement ensures our topological - sort is deterministic. - 4. We then turn our attention to the other arrays that are allocated to a - buffer. These are the materialized arrays and belong to one of the - following classes: - - An :class:`~pytato.Array` tagged with :class:`pytato.tags.ImplStored`. - - The expressions in a :class:`~pytato.DictOfNamedArrays`. - 5. Based on *outputs*, we compute the predecessors of a materialized - that were a part of the mandatory part outputs. A materialized array - is scheduled to be evaluated in a part as soon as all of its inputs - are available. Note that certain inputs (like - :class:`~pytato.DistributedRecv`) might not be available until - certain mandatory part outputs have been evaluated. - 6. From *outputs*, we can construct a DAG comprising only of mandatory - part outputs and materialized arrays. We mark all materialized - arrays that are being used by nodes in a part that's not the one in - which the materialized array itself was evaluated. Such materialized - arrays are also realized as part outputs. This is done to avoid - recomputations. - - Knobs to tweak the partition: - - 1. By removing dependencies between the mandatory part outputs, the - resulting DAG would lead to fewer number of parts and parts with - more number of nodes in them. Similarly, adding dependencies between - the part outputs would lead to smaller parts. - 2. Tagging nodes with :class:~pytato.tags.ImplStored` would help in - avoiding re-computations. - """ - from pytato.transform import SubsetDependencyMapper - from pytato.array import make_dict_of_named_arrays - from pytato.partition import find_partition + comm_id_to_part_id = { + comm_id: ipart + for ipart, comm_ids in enumerate(part_comm_ids) + for comm_id in comm_ids.send_ids | comm_ids.recv_ids} - # {{{ get partitioning helper data corresponding to the DAG + # }}} - partition_outputs = _MandatoryPartitionOutputsCollector()(outputs) + # {{{ assign each compulsorily materialized array to a part - # materialized_arrays: "extra" arrays that must be stored in a buffer materialized_arrays_collector = _MaterializedArrayCollector() materialized_arrays_collector(outputs) - materialized_arrays = frozenset( - materialized_arrays_collector.materialized_arrays) - partition_outputs - - # }}} - - dep_mapper = SubsetDependencyMapper(partition_outputs) - - # {{{ compute a dependency graph between outputs, schedule and partition - output_to_deps = Map({partition_out: (dep_mapper(partition_out) - - frozenset([partition_out])) - for partition_out in partition_outputs}) + # The sets of arrays below must have a deterministic order in order to ensure + # that the resulting partition is also deterministic + + sent_arrays = _OrderedSet( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) + + received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + + # While receive nodes may be marked as materialized, we shouldn't be + # including them here because we're using them (along with the send nodes) + # as anchors to place *other* materialized data into the batches. + # We could allow sent *arrays* to be included here because they are distinct + # from send *nodes*, but we choose to exclude them in order to simplify the + # processing below. + materialized_arrays = ( + materialized_arrays_collector.materialized_arrays + - received_arrays + - sent_arrays) + + # "mso" for "materialized/sent/output" + output_arrays = _OrderedSet(outputs._data.values()) + mso_arrays = materialized_arrays | sent_arrays | output_arrays + + # FIXME: This gathers up materialized_arrays recursively, leading to + # result sizes potentially quadratic in the number of materialized arrays. + mso_array_dep_mapper = SubsetDependencyMapper(frozenset(mso_arrays)) + + mso_ary_to_first_dep_send_part_id: Dict[Array, int] = { + ary: nparts + for ary in mso_arrays} + for send_id, send_node in lsrdg.local_send_id_to_send_node.items(): + for ary in mso_array_dep_mapper(send_node.data): + mso_ary_to_first_dep_send_part_id[ary] = min( + mso_ary_to_first_dep_send_part_id[ary], + comm_id_to_part_id[send_id]) - output_to_part_id = _linearly_schedule_batches(output_to_deps) + if __debug__: + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) + + mso_ary_to_last_dep_recv_part_id: Dict[Array, int] = { + ary: max( + (comm_id_to_part_id[ + _recv_to_comm_id(local_rank, + cast(DistributedRecv, recvd_ary))] + for recvd_ary in recvd_array_dep_mapper(ary)), + default=-1) + for ary in mso_arrays + } + + assert all( + ( + mso_ary_to_last_dep_recv_part_id[ary] + <= mso_ary_to_first_dep_send_part_id[ary]) + for ary in mso_arrays), \ + "unable to find suitable part for materialized or output array" + + # FIXME: (Seemingly) arbitrary decision, subject to future investigation. + # Evaluation of materialized arrays is pushed as late as possible, + # in order to minimize the amount of computation that might prevent + # data from being sent. + mso_ary_to_part_id: Dict[Array, int] = { + ary: min( + mso_ary_to_first_dep_send_part_id[ary], + nparts-1) + for ary in mso_arrays} # }}} - # {{{ assign each materialized array a partition ID in which it will be placed - - materialized_array_to_output_deps = Map({ary: (dep_mapper(ary) - - frozenset([ary])) - for ary in materialized_arrays}) - materialized_ary_to_part_id = _assign_materialized_arrays_to_part_id( - materialized_arrays, - materialized_array_to_output_deps, - output_to_part_id) + recvd_ary_to_part_id: Dict[Array, int] = { + recvd_ary: ( + comm_id_to_part_id[ + _recv_to_comm_id(local_rank, recvd_ary)]) + for recvd_ary in received_arrays} - assert frozenset(materialized_ary_to_part_id) == materialized_arrays + # "Materialized" arrays are arrays that are tagged with ImplStored, + # i.e. "the outside world" (from the perspective of the partitioner) + # has decided that these arrays will live in memory. + # + # In addition, arrays that are sent and received must also live in memory. + # So, "stored" = "materialized" ∪ "overall outputs" ∪ "communicated" + stored_ary_to_part_id = mso_ary_to_part_id.copy() + stored_ary_to_part_id.update(recvd_ary_to_part_id) - # }}} + assert all(0 <= part_id < nparts + for part_id in stored_ary_to_part_id.values()) - stored_ary_to_part_id = materialized_ary_to_part_id.update(output_to_part_id) + stored_arrays = _OrderedSet(stored_ary_to_part_id) - # {{{ find which materialized arrays have users in multiple parts - # (and promote them to part outputs) + # {{{ find which stored arrays should become part outputs + # (because they are used in not just their local part, but also others) - ary_to_dominant_materialized_deps = ( - _get_array_to_dominant_materialized_deps(outputs, - (materialized_arrays - | partition_outputs))) + direct_preds_getter = DirectPredecessorsGetter() - materialized_arrays_realized_as_partition_outputs = ( - _get_materialized_arrays_promoted_to_partition_outputs( - ary_to_dominant_materialized_deps, - stored_ary_to_part_id, - materialized_arrays)) + def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: + materialized_preds: _OrderedSet[Array] = _OrderedSet() + for pred in direct_preds_getter(ary): + if pred in materialized_arrays: + materialized_preds.add(pred) + else: + materialized_preds |= get_materialized_predecessors(pred) + return materialized_preds + + stored_arrays_promoted_to_part_outputs = { + stored_pred + for stored_ary in stored_arrays + for stored_pred in get_materialized_predecessors(stored_ary) + if (stored_ary_to_part_id[stored_ary] + != stored_ary_to_part_id[stored_pred]) + } # }}} - # {{{ tag each node with its part ID + # Don't be tempted to put outputs in _array_names; the mapping from output array + # to name may not be unique + _array_name_gen = UniqueNameGenerator(forced_prefix="_pt_dist_") + _array_names: Dict[Array, str] = {} - # Why is this necessary? (I.e. isn't the mapping *stored_ary_to_part_id* enough?) - # By assigning tags we also duplicate the non-materialized nodes that are - # to be made available in multiple parts. Parts being disjoint is one of - # the requirements of *find_partition*. - part_id_tag_assigner = _PartIDTagAssigner( - stored_ary_to_part_id, - partition_outputs | materialized_arrays_realized_as_partition_outputs) - - partitioned_outputs = make_dict_of_named_arrays({ - name: part_id_tag_assigner(subexpr, stored_ary_to_part_id[subexpr]) - for name, subexpr in outputs._data.items()}) - - # }}} - - def get_part_id(expr: ArrayOrNames) -> int: - if not isinstance(expr, Array): - raise NotImplementedError("find_distributed_partition" - " cannot partition DictOfNamedArrays") - assert isinstance(expr, Array) - tag, = expr.tags_of_type(PartIDTag) - assert isinstance(tag, PartIDTag) - return tag.part_id - - gp = cast(DistributedGraphPartition, - find_partition(partitioned_outputs, - get_part_id, - _DistributedGraphPartitioner)) - - # Remove PartIDTag from arrays that may be returned from the evaluation - # of the partitioned graph. If we don't, those may end up on inputs to - # another graph, which may also get partitioned, which will endlessly - # confuse that subsequent partitioning process. In addition, those - # tags may cause arrays to look spuriously different, defeating - # caching. - # See https://github.com/inducer/pytato/issues/307 for further discussion. - - # Note that it does not suffice to remove those tags from just, say, - # var_name_to_result: This may produce inconsistent Placeholder instances. - # For the same reason, we need to use the same mapper for all nodes. - from pytato.transform import CachedMapAndCopyMapper - cmac = CachedMapAndCopyMapper(_remove_part_id_tag) - - def map_array(ary: Array) -> Array: - result = cmac(ary) - assert isinstance(result, Array) - return result + def gen_array_name(ary: Array) -> str: + name = _array_names.get(ary) + if name is not None: + return name + else: + name = _array_name_gen() + _array_names[ary] = name + return name + + recvd_ary_to_name: Dict[Array, str] = { + ary: gen_array_name(ary) + for ary in received_arrays} + + name_to_output_per_part: List[Dict[str, Array]] = [{} for _pid in range(nparts)] + + for name, ary in outputs._data.items(): + pid = stored_ary_to_part_id[ary] + name_to_output_per_part[pid][name] = ary + + sent_ary_to_name: Dict[Array, str] = {} + for ary in sent_arrays: + pid = stored_ary_to_part_id[ary] + name = gen_array_name(ary) + sent_ary_to_name[ary] = name + name_to_output_per_part[pid][name] = ary + + sptpo_ary_to_name: Dict[Array, str] = {} + for ary in stored_arrays_promoted_to_part_outputs: + pid = stored_ary_to_part_id[ary] + name = gen_array_name(ary) + sptpo_ary_to_name[ary] = name + name_to_output_per_part[pid][name] = ary + + partition = _make_distributed_partition( + name_to_output_per_part, + part_comm_ids, + recvd_ary_to_name, + sent_ary_to_name, + sptpo_ary_to_name, + lsrdg.local_recv_id_to_recv_node, + lsrdg.local_send_id_to_send_node) + + from pytato.distributed.verify import _run_partition_diagnostics + _run_partition_diagnostics(outputs, partition) - def map_send(send: DistributedSend) -> DistributedSend: - return send.copy(data=cmac(send.data)) + if __debug__: + # Avoid potentially confusing errors if one rank manages to continue + # when another is not able. + mpi_communicator.barrier() - return _map_distributed_graph_partition_nodes(map_array, map_send, gp) + return partition # }}} diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index c0d45f9..4fa419c 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -63,11 +63,12 @@ def number_distributed_tags( tags = frozenset({ recv.comm_tag for part in partition.parts.values() - for recv in part.input_name_to_recv_node.values() + for recv in part.name_to_recv_node.values() } | { send.comm_tag for part in partition.parts.values() - for send in part.output_name_to_send_node.values()}) + for sends in part.name_to_send_nodes.values() + for send in sends}) from mpi4py import MPI @@ -110,17 +111,19 @@ def number_distributed_tags( return DistributedGraphPartition( parts={ pid: replace(part, - input_name_to_recv_node={ + name_to_recv_node={ name: recv.copy(comm_tag=sym_tag_to_int_tag[recv.comm_tag]) - for name, recv in part.input_name_to_recv_node.items()}, - output_name_to_send_node={ - name: send.copy(comm_tag=sym_tag_to_int_tag[send.comm_tag]) - for name, send in part.output_name_to_send_node.items()}, + for name, recv in part.name_to_recv_node.items()}, + name_to_send_nodes={ + name: [ + send.copy(comm_tag=sym_tag_to_int_tag[send.comm_tag]) + for send in sends] + for name, sends in part.name_to_send_nodes.items()}, ) for pid, part in partition.parts.items() }, - var_name_to_result=partition.var_name_to_result, - toposorted_part_ids=partition.toposorted_part_ids), next_tag + name_to_output=partition.name_to_output, + ), next_tag # }}} diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index e37d8a8..c324b67 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -1,4 +1,9 @@ """ +Verification +------------ + +.. autoexception:: PartitionInducedCycleError + .. currentmodule:: pytato .. autofunction:: verify_distributed_partition """ @@ -30,17 +35,20 @@ THE SOFTWARE. """ -from typing import Any, FrozenSet, Dict, Set, Optional, Sequence, TYPE_CHECKING +from typing import Any, List, FrozenSet, Dict, Set, Optional, Sequence, TYPE_CHECKING +import attrs import numpy as np +from pymbolic.mapper.optimize import optimize_mapper + +from pytato.array import DictOfNamedArrays, make_dict_of_named_arrays, Placeholder +from pytato.transform import ArrayOrNames, CachedWalkMapper from pytato.distributed.nodes import CommTagType, DistributedRecv -from pytato.partition import PartId -from pytato.distributed.partition import DistributedGraphPartition +from pytato.distributed.partition import ( + PartId, DistributedGraphPartition, CommunicationOpIdentifier) from pytato.array import ShapeType -import attrs - import logging logger = logging.getLogger(__name__) @@ -81,25 +89,25 @@ class _SummarizedDistributedGraphPart: user_input_names: FrozenSet[_DistributedName] partition_input_names: FrozenSet[_DistributedName] output_names: FrozenSet[_DistributedName] - input_name_to_recv_node: Dict[_DistributedName, DistributedRecv] - output_name_to_send_node: Dict[_DistributedName, _SummarizedDistributedSend] + name_to_recv_node: Dict[_DistributedName, DistributedRecv] + name_to_send_nodes: Dict[_DistributedName, List[_SummarizedDistributedSend]] @property def rank(self) -> int: return self.pid.rank - -@attrs.define(frozen=True) -class _CommIdentifier: - src_rank: int - dest_rank: int - comm_tag: CommTagType - # }}} # {{{ errors +class PartitionInducedCycleError(AssertionError): + """Raised by if the partitioning (e.g. via + :func:`~pytato.find_distributed_partition`) erroneously induced a cycle in the + graph of partitions. + """ + + class DistributedPartitionVerificationError(ValueError): pass @@ -122,6 +130,74 @@ class MissingRecvError(DistributedPartitionVerificationError): # }}} +# {{{ _check_partition_disjointness + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class _SeenNodesWalkMapper(CachedWalkMapper): + def __init__(self) -> None: + super().__init__() + self.seen_nodes: Set[ArrayOrNames] = set() + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def visit(self, expr: ArrayOrNames) -> bool: # type: ignore[override] + super().visit(expr) + self.seen_nodes.add(expr) + return True + + +def _check_partition_disjointness(partition: DistributedGraphPartition) -> None: + part_id_to_nodes: Dict[PartId, Set[ArrayOrNames]] = {} + + for part in partition.parts.values(): + mapper = _SeenNodesWalkMapper() + for out_name in part.output_names: + mapper(partition.name_to_output[out_name]) + + # FIXME This check won't do much unless we successfully visit + # all the nodes, but we're not currently checking that. + for my_node in mapper.seen_nodes: + for other_part_id, other_node_set in part_id_to_nodes.items(): + # Placeholders represent values computed in one partition + # and used in one or more other ones. As a result, the + # same placeholder may occur in more than one partition. + if not (isinstance(my_node, Placeholder) + or my_node not in other_node_set): + raise RuntimeError( + "Partitions not disjoint: " + f"{my_node.__class__.__name__} (id={hex(id(my_node))}) " + f"in both '{part.pid}' and '{other_part_id}'" + f"{part.output_names=} " + f"{partition.parts[other_part_id].output_names=} ") + + part_id_to_nodes[part.pid] = mapper.seen_nodes + +# }}} + + +# {{{ _run_partition_diagnostics + +def _run_partition_diagnostics( + outputs: DictOfNamedArrays, gp: DistributedGraphPartition) -> None: + # FIXME: Is it reasonable to require this? + # if __debug__: + # _check_partition_disjointness(gp) + + from pytato.analysis import get_num_nodes + num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays( + {x: gp.name_to_output[x] for x in part.output_names})) + for part in gp.parts.values()] + + logger.info(f"find_partition: Split {get_num_nodes(outputs)} nodes into " + f"{len(gp.parts)} parts, with {num_nodes_per_part} nodes in each " + "partition.") + +# }}} + + # {{{ verify_distributed_partition def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, @@ -162,17 +238,18 @@ def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, for name in part.partition_input_names]), output_names=frozenset([_DistributedName(my_rank, name) for name in part.output_names]), - input_name_to_recv_node={_DistributedName(my_rank, name): recv - for name, recv in part.input_name_to_recv_node.items()}, - output_name_to_send_node={ - _DistributedName(my_rank, name): - _SummarizedDistributedSend( - src_rank=my_rank, - dest_rank=send.dest_rank, - comm_tag=send.comm_tag, - shape=send.data.shape, - dtype=send.data.dtype) - for name, send in part.output_name_to_send_node.items()}) + name_to_recv_node={_DistributedName(my_rank, name): recv + for name, recv in part.name_to_recv_node.items()}, + name_to_send_nodes={ + _DistributedName(my_rank, name): [ + _SummarizedDistributedSend( + src_rank=my_rank, + dest_rank=send.dest_rank, + comm_tag=send.comm_tag, + shape=send.data.shape, + dtype=send.data.dtype) + for send in sends] + for name, sends in part.name_to_send_nodes.items()}) # Gather the _SummarizedDistributedGraphPart's to rank 0 all_summarized_parts_gathered: Optional[ @@ -194,32 +271,45 @@ def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, needed_pid: _DistributedPartId) -> None: pid_to_needed_pids.setdefault(pid, set()).add(needed_pid) - all_recvs: Set[_CommIdentifier] = set() + all_recvs: Set[CommunicationOpIdentifier] = set() # {{{ gather information on who produces output - output_to_defining_pid: Dict[_DistributedName, _DistributedPartId] = {} + name_to_computing_pid: Dict[_DistributedName, _DistributedPartId] = {} for sumpart in all_summarized_parts.values(): for out_name in sumpart.output_names: - assert out_name not in output_to_defining_pid - output_to_defining_pid[out_name] = sumpart.pid + assert out_name not in name_to_computing_pid + name_to_computing_pid[out_name] = sumpart.pid # }}} - # {{{ gather information on senders + # {{{ gather information on who receives which names - comm_id_to_sending_pid: Dict[_CommIdentifier, _DistributedPartId] = {} + name_to_receiving_pid: Dict[_DistributedName, _DistributedPartId] = {} for sumpart in all_summarized_parts.values(): - for sumsend in sumpart.output_name_to_send_node.values(): - comm_id = _CommIdentifier( - src_rank=sumsend.src_rank, - dest_rank=sumsend.dest_rank, - comm_tag=sumsend.comm_tag) + for recv_name in sumpart.name_to_recv_node: + assert recv_name not in name_to_computing_pid + assert recv_name not in name_to_receiving_pid + name_to_receiving_pid[recv_name] = sumpart.pid + + # }}} + + # {{{ gather information on senders - if comm_id in comm_id_to_sending_pid: - raise DuplicateSendError( - f"duplicate send for comm id: '{comm_id}'") - comm_id_to_sending_pid[comm_id] = sumpart.pid + comm_id_to_sending_pid: \ + Dict[CommunicationOpIdentifier, _DistributedPartId] = {} + for sumpart in all_summarized_parts.values(): + for sumsends in sumpart.name_to_send_nodes.values(): + for sumsend in sumsends: + comm_id = CommunicationOpIdentifier( + src_rank=sumsend.src_rank, + dest_rank=sumsend.dest_rank, + comm_tag=sumsend.comm_tag) + + if comm_id in comm_id_to_sending_pid: + raise DuplicateSendError( + f"duplicate send for comm id: '{comm_id}'") + comm_id_to_sending_pid[comm_id] = sumpart.pid # }}} @@ -230,8 +320,8 @@ def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, # Loop through all receives, assert that combination of # (src_rank, dest_rank, tag) is unique. - for dname, dist_recv in sumpart.input_name_to_recv_node.items(): - comm_id = _CommIdentifier( + for dname, dist_recv in sumpart.name_to_recv_node.items(): + comm_id = CommunicationOpIdentifier( src_rank=dist_recv.src_rank, dest_rank=dname.rank, comm_tag=dist_recv.comm_tag) @@ -252,12 +342,23 @@ def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, # Add edges between output_names and partition_input_names (intra-rank) for input_name in sumpart.partition_input_names: - # Input names from recv nodes have no corresponding output_name - if input_name in sumpart.input_name_to_recv_node.keys(): - continue - defining_pid = output_to_defining_pid[input_name] - assert defining_pid.rank == sumpart.pid.rank - add_needed_pid(sumpart.pid, defining_pid) + defining_pid = name_to_computing_pid.get(input_name) + + if defining_pid is None: + defining_pid = name_to_receiving_pid.get(input_name) + + if defining_pid is None: + raise AssertionError( + f"name '{input_name}' in part {sumpart} not defined " + "via output or receive") + + if defining_pid == sumpart.pid: + # Yes, we look at our own sends. But we don't need to + # include an edge for them--it'll look like a cycle. + pass + else: + assert defining_pid.rank == sumpart.pid.rank + add_needed_pid(sumpart.pid, defining_pid) # }}} @@ -269,7 +370,7 @@ def verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, # Do a topological sort to check for any cycles from pytools.graph import compute_topological_order, CycleError - from pytato.partition import PartitionInducedCycleError + from pytato.distributed.verify import PartitionInducedCycleError try: compute_topological_order(pid_to_needed_pids) except CycleError: diff --git a/pytato/partition.py b/pytato/partition.py deleted file mode 100644 index 1cfa5f5..0000000 --- a/pytato/partition.py +++ /dev/null @@ -1,474 +0,0 @@ -from __future__ import annotations - -__copyright__ = """ -Copyright (C) 2021 University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from typing import (Any, Callable, Dict, Union, Set, List, Hashable, Tuple, TypeVar, - FrozenSet, Mapping, Optional, Type) -import attrs - -import logging -logger = logging.getLogger(__name__) - -from pytools import memoize_method -from pytato.transform import EdgeCachedMapper, CachedWalkMapper -from pytato.array import ( - Array, AbstractResultWithNamedArrays, Placeholder, - DictOfNamedArrays, make_placeholder, make_dict_of_named_arrays) - -from pytato.target import BoundProgram -from pymbolic.mapper.optimize import optimize_mapper - - -__doc__ = """ -Partitioning of graphs in :mod:`pytato` currently mainly serves to enable -:ref:`distributed computation `, i.e. sending and receiving data -as part of graph evaluation. - -However, as implemented, it is completely general and not specific to this use -case. Partitioning of expression graphs is based on a few assumptions: - -- We must be able to execute parts in any dependency-respecting order. -- Parts are compiled at partitioning time, so what inputs they take from memory - vs. what they compute is decided at that time. -- No part may depend on its own outputs as inputs. - (cf. :exc:`PartitionInducedCycleError`) - -.. autoclass:: GraphPart -.. autoclass:: GraphPartition -.. autoclass:: GraphPartitioner -.. autoexception:: PartitionInducedCycleError - -.. autofunction:: find_partition -.. autofunction:: execute_partition - -Internal stuff that is only here because the documentation tool wants it -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. class:: T - - A type variable for :class:`~pytato.array.AbstractResultWithNamedArrays`. -""" - - -ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] -T = TypeVar("T", bound=ArrayOrNames) -PartId = Hashable - - -# {{{ graph partitioner - -class GraphPartitioner(EdgeCachedMapper): - """Given a function *get_part_id*, produces subgraphs representing - the computation. Users should not use this class directly, but use - :meth:`find_partition` instead. - - .. automethod:: __init__ - .. automethod:: __call__ - .. automethod:: make_partition - """ - - def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: - super().__init__() - - # Function to determine the part ID - self._get_part_id: Callable[[ArrayOrNames], PartId] = \ - get_part_id - - # Naming for newly created PlaceHolders at part edges - from pytools import UniqueNameGenerator - self.name_generator = UniqueNameGenerator(forced_prefix="_pt_part_ph_") - - # "edges" of the partitioned graph, maps an edge between two parts, - # represented by a tuple of part identifiers, to a set of placeholder - # names "conveying" information across the edge. - self.part_pair_to_edges: Dict[Tuple[PartId, PartId], - Set[str]] = {} - - self.var_name_to_result: Dict[str, Array] = {} - - self._seen_node_to_placeholder: Dict[ArrayOrNames, Placeholder] = {} - - # Reading the seen part IDs out of part_pair_to_edges is incorrect: - # e.g. if each part is self-contained, no edges would appear. Instead, - # we remember each part ID we see below, to guarantee that we don't - # miss any of them. - self.seen_part_ids: Set[PartId] = set() - - self.pid_to_user_input_names: Dict[PartId, Set[str]] = {} - - def get_part_id(self, expr: ArrayOrNames) -> PartId: - part_id = self._get_part_id(expr) - self.seen_part_ids.add(part_id) - return part_id - - def does_edge_cross_part_boundary(self, - node1: ArrayOrNames, node2: ArrayOrNames) -> bool: - return self.get_part_id(node1) != self.get_part_id(node2) - - def make_new_placeholder_name(self) -> str: - return self.name_generator() - - def add_inter_part_edge(self, target: ArrayOrNames, dependency: ArrayOrNames, - placeholder_name: str) -> None: - pid_target = self.get_part_id(target) - pid_dependency = self.get_part_id(dependency) - - self.part_pair_to_edges.setdefault( - (pid_target, pid_dependency), set()).add(placeholder_name) - - def handle_edge(self, expr: ArrayOrNames, child: ArrayOrNames) -> Any: - if self.does_edge_cross_part_boundary(expr, child): - try: - ph = self._seen_node_to_placeholder[child] - except KeyError: - ph_name = self.make_new_placeholder_name() - # If an edge crosses a part boundary, replace the - # depended-upon node (that nominally lives in the other part) - # with a Placeholder that lives in the current part. For each - # part, collect the placeholder names that it’s supposed to - # compute. - - if not isinstance(child, Array): - raise NotImplementedError("not currently supporting " - "DictOfNamedArrays in the middle of graph " - "partitioning") - - ph = make_placeholder(ph_name, - shape=child.shape, - dtype=child.dtype, - tags=child.tags, - axes=child.axes) - - # type-ignore-reason: mypy is right, types of self.rec are - # imprecise (TODO) - self.var_name_to_result[ph_name] = ( - self.rec(child)) # type: ignore[assignment] - - self._seen_node_to_placeholder[child] = ph - - assert ph.name - self.add_inter_part_edge(expr, child, ph.name) - return ph - - else: - return self.rec(child) - - def __call__(self, expr: T, *args: Any, **kwargs: Any) -> Any: - # Need to make sure the first node's part is 'seen' - self.get_part_id(expr) - - return super().__call__(expr, *args, **kwargs) - - def make_partition(self, outputs: DictOfNamedArrays) -> GraphPartition: - rewritten_outputs = { - name: self(expr) for name, expr in sorted(outputs._data.items())} - - pid_to_output_names: Dict[PartId, Set[str]] = { - pid: set() for pid in self.seen_part_ids} - pid_to_input_names: Dict[PartId, Set[str]] = { - pid: set() for pid in self.seen_part_ids} - - var_name_to_result = self.var_name_to_result.copy() - - for out_name, rewritten_output in sorted(rewritten_outputs.items()): - out_part_id = self._get_part_id(outputs._data[out_name]) - pid_to_output_names.setdefault(out_part_id, set()).add(out_name) - var_name_to_result[out_name] = rewritten_output - - # Mapping of nodes to their successors; used to compute the topological order - pid_to_needing_pids: Dict[PartId, Set[PartId]] = { - pid: set() for pid in self.seen_part_ids} - pid_to_needed_pids: Dict[PartId, Set[PartId]] = { - pid: set() for pid in self.seen_part_ids} - - for (pid_target, pid_dependency), var_names in \ - self.part_pair_to_edges.items(): - pid_to_needing_pids[pid_dependency].add(pid_target) - pid_to_needed_pids[pid_target].add(pid_dependency) - - for var_name in var_names: - pid_to_output_names[pid_dependency].add(var_name) - pid_to_input_names[pid_target].add(var_name) - - from pytools.graph import compute_topological_order, CycleError - try: - toposorted_part_ids = compute_topological_order( - pid_to_needing_pids, - lambda x: sorted(pid_to_output_names[x])) - except CycleError: - raise PartitionInducedCycleError - - return GraphPartition( - parts={ - pid: GraphPart( - pid=pid, - needed_pids=frozenset(pid_to_needed_pids[pid]), - user_input_names=frozenset( - self.pid_to_user_input_names.get(pid, set())), - partition_input_names=frozenset(pid_to_input_names[pid]), - output_names=frozenset(pid_to_output_names[pid]), - ) - for pid in self.seen_part_ids}, - var_name_to_result=var_name_to_result, - toposorted_part_ids=toposorted_part_ids) - - def map_placeholder(self, expr: Placeholder, *args: Any) -> Any: - pid = self.get_part_id(expr) - self.pid_to_user_input_names.setdefault(pid, set()).add(expr.name) - return super().map_placeholder(expr) - -# }}} - - -# {{{ graph partition - -@attrs.define(frozen=True, slots=False) -class GraphPart: - """ - .. attribute:: pid - - An identifier for this part of the graph. - - .. attribute:: needed_pids - - The IDs of parts that are required to be evaluated before this - part can be evaluated. - - .. attribute:: user_input_names - - A :class:`frozenset` of names representing input to the computational - graph, i.e. which were *not* introduced by partitioning. - - .. attribute:: partition_input_names - - A :class:`frozenset` of names of placeholders the part requires as - input from other parts in the partition. - - .. attribute:: output_names - - Names of placeholders this part provides as output. - - .. automethod:: all_input_names - """ - pid: PartId - needed_pids: FrozenSet[PartId] - user_input_names: FrozenSet[str] - partition_input_names: FrozenSet[str] - output_names: FrozenSet[str] - - @memoize_method - def all_input_names(self) -> FrozenSet[str]: - return self.user_input_names | self. partition_input_names - - -@attrs.define(frozen=True, slots=False) -class GraphPartition: - """Store information about a partitioning of an expression graph. - - .. attribute:: parts - - Mapping from part IDs to instances of :class:`GraphPart`. - - .. attribute:: var_name_to_result - - Mapping of placeholder names to the respective :class:`pytato.array.Array` - they represent. - - .. attribute:: toposorted_part_ids - - One possible topologically sorted ordering of part IDs that is - admissible under :attr:`GraphPart.needed_pids`. - - .. note:: - - This attribute could be recomputed for those dependencies. Since it - is computed as part of :func:`find_partition` anyway, it is - preserved here. - """ - parts: Mapping[PartId, GraphPart] - var_name_to_result: Mapping[str, Array] - toposorted_part_ids: List[PartId] - -# }}} - - -class PartitionInducedCycleError(Exception): - """Raised by :func:`find_partition` if the partitioning induced a - cycle in the graph of partitions. - """ - - -# {{{ find_partition - -def find_partition(outputs: DictOfNamedArrays, - part_func: Callable[[ArrayOrNames], PartId], - partitioner_class: Type[GraphPartitioner] = GraphPartitioner) ->\ - GraphPartition: - """Partitions the *expr* according to *part_func* and generates code for - each partition. Raises :exc:`PartitionInducedCycleError` if the partitioning - induces a cycle, e.g. for a graph like the following:: - - ┌───┐ - ┌──┤ A ├──┐ - │ └───┘ │ - │ ┌─▼─┐ - │ │ B │ - │ └─┬─┘ - │ ┌───┐ │ - └─►│ C │◄─┘ - └───┘ - - where ``A`` and ``C`` are in partition 1, and ``B`` is in partition 2. - - :param outputs: The outputs to partition. - :param part_func: A callable that returns an instance of - :class:`Hashable` for a node. - :param partitioner_class: A :class:`GraphPartitioner` to - guide the partitioning. - :returns: An instance of :class:`GraphPartition` that contains the partition. - """ - - result = partitioner_class(part_func).make_partition(outputs) - - # {{{ Check partitions and log statistics - - if __debug__: - _check_partition_disjointness(result) - - from pytato.analysis import get_num_nodes - num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays( - {x: result.var_name_to_result[x] for x in part.output_names})) - for part in result.parts.values()] - - logger.info(f"find_partition: Split {get_num_nodes(outputs)} nodes into " - f"{len(result.parts)} parts, with {num_nodes_per_part} nodes in each " - "partition.") - - # }}} - - return result - - -@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) -class _SeenNodesWalkMapper(CachedWalkMapper): - def __init__(self) -> None: - super().__init__() - self.seen_nodes: Set[ArrayOrNames] = set() - - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] - return id(expr) - - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def visit(self, expr: ArrayOrNames) -> bool: # type: ignore[override] - super().visit(expr) - self.seen_nodes.add(expr) - return True - - -def _check_partition_disjointness(partition: GraphPartition) -> None: - part_id_to_nodes: Dict[PartId, Set[ArrayOrNames]] = {} - - for part in partition.parts.values(): - mapper = _SeenNodesWalkMapper() - for out_name in part.output_names: - mapper(partition.var_name_to_result[out_name]) - - # FIXME This check won't do much unless we successfully visit - # all the nodes, but we're not currently checking that. - for my_node in mapper.seen_nodes: - for other_part_id, other_node_set in part_id_to_nodes.items(): - # Placeholders represent values computed in one partition - # and used in one or more other ones. As a result, the - # same placeholder may occur in more than one partition. - if not (isinstance(my_node, Placeholder) - or my_node not in other_node_set): - raise RuntimeError( - "Partitions not disjoint: " - f"{my_node.__class__.__name__} (id={hex(id(my_node))}) " - f"in both '{part.pid}' and '{other_part_id}'" - f"{part.output_names=} " - f"{partition.parts[other_part_id].output_names=} ") - - part_id_to_nodes[part.pid] = mapper.seen_nodes - -# }}} - - -# {{{ generate_code_for_partition - -def generate_code_for_partition(partition: GraphPartition) \ - -> Mapping[PartId, BoundProgram]: - """Return a mapping of partition identifiers to their - :class:`pytato.target.BoundProgram`.""" - from pytato import generate_loopy - part_id_to_prg = {} - - for part in sorted(partition.parts.values(), - key=lambda part_: sorted(part_.output_names)): - d = make_dict_of_named_arrays( - {var_name: partition.var_name_to_result[var_name] - for var_name in part.output_names - }) - part_id_to_prg[part.pid] = generate_loopy(d) - - return part_id_to_prg - -# }}} - - -# {{{ execute_partitions - -def execute_partition(partition: GraphPartition, prg_per_partition: - Dict[PartId, BoundProgram], queue: Any, - input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Executes a set of partitions on a :class:`pyopencl.CommandQueue`. - - :param parts: An instance of :class:`GraphPartition` representing the - partitioned code. - :param queue: An instance of :class:`pyopencl.CommandQueue` to execute the - code on. - :returns: A dictionary of variable names mapped to their values. - """ - if input_args is None: - input_args = {} - - context: Dict[str, Any] = input_args.copy() - - for pid in partition.toposorted_part_ids: - part = partition.parts[pid] - inputs = { - k: context[k] for k in part.all_input_names() - if k in context} - - _evt, result_dict = prg_per_partition[pid](queue=queue, **inputs) - context.update(result_dict) - - return context - -# }}} - - -# vim: foldmethod=marker diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 23898a3..ebcd306 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -28,7 +28,6 @@ THE SOFTWARE. import logging import numpy as np -from abc import abstractmethod from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, Hashable) @@ -74,7 +73,6 @@ __doc__ = """ .. autoclass:: CachedWalkMapper .. autoclass:: TopoSortMapper .. autoclass:: CachedMapAndCopyMapper -.. autoclass:: EdgeCachedMapper .. autofunction:: copy_dict_of_named_arrays .. autofunction:: get_dependencies .. autofunction:: map_and_copy @@ -1570,179 +1568,6 @@ def tag_user_nodes( # }}} -# {{{ EdgeCachedMapper - -class EdgeCachedMapper(CachedMapper[ArrayOrNames]): - """ - Mapper class to execute a rewriting method (:meth:`handle_edge`) on each - edge in the graph. - - .. automethod:: handle_edge - """ - - @abstractmethod - def handle_edge(self, expr: ArrayOrNames, child: ArrayOrNames) -> Any: - pass - - def rec_idx_or_size_tuple(self, - expr: Array, - situp: Tuple[IndexOrShapeExpr, ...], - *args: Any) -> Tuple[IndexOrShapeExpr, ...]: - return tuple([ - self.handle_edge(expr, dim, *args) if isinstance(dim, Array) else dim - for dim in situp]) - - # {{{ map_xxx methods - - def map_named_array(self, expr: NamedArray, *args: Any) -> NamedArray: - return type(expr)( - self.handle_edge(expr, expr._container, *args), - name=expr.name, - axes=expr.axes, - tags=expr.tags) - - def map_index_lambda(self, expr: IndexLambda, *args: Any) -> IndexLambda: - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr, expr.shape), - dtype=expr.dtype, - bindings={name: self.handle_edge(expr, child) - for name, child in sorted(expr.bindings.items())}, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) - - def map_einsum(self, expr: Einsum, *args: Any) -> Einsum: - return Einsum( - access_descriptors=expr.access_descriptors, - args=tuple(self.handle_edge(expr, arg, *args) - for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags) - - def map_stack(self, expr: Stack, *args: Any) -> Stack: - return Stack( - arrays=tuple(self.handle_edge(expr, ary, *args) - for ary in expr.arrays), - axis=expr.axis, - axes=expr.axes, - tags=expr.tags) - - def map_concatenate(self, expr: Concatenate, *args: Any) -> Concatenate: - return Concatenate( - arrays=tuple(self.handle_edge(expr, ary, *args) - for ary in expr.arrays), - axis=expr.axis, - axes=expr.axes, - tags=expr.tags) - - def map_roll(self, expr: Roll, *args: Any) -> Roll: - return Roll(array=self.handle_edge(expr, expr.array, *args), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags) - - def map_axis_permutation(self, expr: AxisPermutation, *args: Any) \ - -> AxisPermutation: - return AxisPermutation( - array=self.handle_edge(expr, expr.array, *args), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags) - - def map_reshape(self, expr: Reshape, *args: Any) -> Reshape: - return Reshape( - array=self.handle_edge(expr, expr.array, *args), - newshape=self.rec_idx_or_size_tuple(expr, expr.newshape, *args), - order=expr.order, - axes=expr.axes, - tags=expr.tags) - - def map_basic_index(self, expr: BasicIndex, *args: Any) -> BasicIndex: - return BasicIndex( - array=self.handle_edge(expr, expr.array, *args), - indices=tuple(self.handle_edge(expr, idx, *args) - if isinstance(idx, Array) else idx - for idx in expr.indices), - axes=expr.axes, - tags=expr.tags) - - def map_contiguous_advanced_index(self, - expr: AdvancedIndexInContiguousAxes, *args: Any) \ - -> AdvancedIndexInContiguousAxes: - return AdvancedIndexInContiguousAxes( - array=self.handle_edge(expr, expr.array, *args), - indices=tuple(self.handle_edge(expr, idx, *args) - if isinstance(idx, Array) else idx - for idx in expr.indices), - axes=expr.axes, - tags=expr.tags) - - def map_non_contiguous_advanced_index(self, - expr: AdvancedIndexInNoncontiguousAxes, *args: Any) \ - -> AdvancedIndexInNoncontiguousAxes: - return AdvancedIndexInNoncontiguousAxes( - array=self.handle_edge(expr, expr.array, *args), - indices=tuple(self.handle_edge(expr, idx, *args) - if isinstance(idx, Array) else idx - for idx in expr.indices), - axes=expr.axes, - tags=expr.tags) - - def map_data_wrapper(self, expr: DataWrapper, *args: Any) -> DataWrapper: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr, expr.shape, *args), - axes=expr.axes, - tags=expr.tags) - - def map_placeholder(self, expr: Placeholder, *args: Any) -> Placeholder: - assert expr.name - - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr, expr.shape, *args), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags) - - def map_size_param(self, expr: SizeParam, *args: Any) -> SizeParam: - assert expr.name - return SizeParam(expr.name, axes=expr.axes, tags=expr.tags) - - def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - return LoopyCall( - translation_unit=expr.translation_unit, - entrypoint=expr.entrypoint, - bindings={ - name: self.handle_edge(expr, child) - if isinstance(child, Array) else child - for name, child in sorted(expr.bindings.items())}, - tags=expr.tags, - ) - - def map_distributed_send_ref_holder( - self, expr: DistributedSendRefHolder, *args: Any) -> \ - DistributedSendRefHolder: - return DistributedSendRefHolder( - send=self.handle_edge(expr, expr.send.data), - passthrough_data=self.handle_edge(expr, expr.passthrough_data), - tags=expr.tags - ) - - def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \ - -> Any: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr, expr.shape, *args), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes) - - # }}} - -# }}} - - # {{{ deduplicate_data_wrappers def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: diff --git a/pytato/visualization.py b/pytato/visualization.py index 4bdb103..803ceb0 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -46,8 +46,8 @@ from pytato.array import ( from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames -from pytato.partition import GraphPartition, PartId -from pytato.distributed.partition import DistributedGraphPart +from pytato.distributed.partition import ( + DistributedGraphPartition, DistributedGraphPart, PartId) if TYPE_CHECKING: from pytato.distributed.nodes import DistributedSendRefHolder @@ -359,11 +359,11 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: return emit.get() -def get_dot_graph_from_partition(partition: GraphPartition) -> str: +def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: r"""Return a string in the `dot `_ language depicting the graph of the partitioned computation of *partition*. - :arg partition: Outputs of :func:`~pytato.partition.find_partition`. + :arg partition: Outputs of :func:`~pytato.find_distributed_partition`. """ # Maps each partition to a dict of its arrays with the node info part_id_to_node_info: Dict[Hashable, Dict[ArrayOrNames, DotNodeInfo]] = {} @@ -371,7 +371,7 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: for part in partition.parts.values(): mapper = ArrayToDotNodeInfoMapper() for out_name in part.output_names: - mapper(partition.var_name_to_result[out_name]) + mapper(partition.name_to_output[out_name]) part_id_to_node_info[part.pid] = mapper.nodes @@ -411,7 +411,7 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: if isinstance(part, DistributedGraphPart): part_dist_recv_var_name_to_node_id = {} for name, recv in ( - part.input_name_to_recv_node.items()): + part.name_to_recv_node.items()): node_id = id_gen("recv") _emit_array(emit, "DistributedRecv", { "shape": stringify_shape(recv.shape), @@ -471,7 +471,7 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: break assert computing_pid is not None tgt = part_id_to_array_to_id[computing_pid][ - partition.var_name_to_result[array.name]] + partition.name_to_output[array.name]] emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") emitted_placeholders.add(array) @@ -500,17 +500,18 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: deferred_send_edges = [] if isinstance(part, DistributedGraphPart): - for name, send in ( - part.output_name_to_send_node.items()): - node_id = id_gen("send") - _emit_array(emit, "DistributedSend", { - "dest_rank": str(send.dest_rank), - "comm_tag": str(send.comm_tag), - }, node_id) - - deferred_send_edges.append( - f"{array_to_id[send.data]} -> {node_id}" - f'[style=dotted, label="{dot_escape(name)}"]') + for name, sends in ( + part.name_to_send_nodes.items()): + for send in sends: + node_id = id_gen("send") + _emit_array(emit, "DistributedSend", { + "dest_rank": str(send.dest_rank), + "comm_tag": str(send.comm_tag), + }, node_id) + + deferred_send_edges.append( + f"{array_to_id[send.data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') # }}} @@ -530,20 +531,21 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: return emit.get() -def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, GraphPartition], +def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, + DistributedGraphPartition], **kwargs: Any) -> None: """Show a graph representing the computation of *result* in a browser. :arg result: Outputs of the computation (cf. :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`, - or the output of :func:`~pytato.partition.find_partition`. + or the output of :func:`~pytato.find_distributed_partition`. :arg kwargs: Passed on to :func:`pytools.graphviz.show_dot` unmodified. """ dot_code: str if isinstance(result, str): dot_code = result - elif isinstance(result, GraphPartition): + elif isinstance(result, DistributedGraphPartition): dot_code = get_dot_graph_from_partition(result) else: dot_code = get_dot_graph(result) diff --git a/test/test_codegen.py b/test/test_codegen.py index f0e4469..8749065 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1386,73 +1386,6 @@ def test_random_dag_against_numpy(ctx_factory): assert np.allclose(pt_result["result"], ref_result) -def test_partitioner(ctx_factory): - ctx = ctx_factory() - queue = cl.CommandQueue(ctx) - - from testlib import RandomDAGContext, make_random_dag - - axis_len = 5 - - ntests = 50 - ncycles = 0 - for i in range(ntests): - print(i) - seed = 120 + i - rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), - axis_len=axis_len, use_numpy=False) - rdagc_np = RandomDAGContext(np.random.default_rng(seed=seed), - axis_len=axis_len, use_numpy=True) - - ref_result = make_random_dag(rdagc_np) - - from pytato.transform import materialize_with_mpms - dict_named_arys = materialize_with_mpms(pt.make_dict_of_named_arrays( - {"result": make_random_dag(rdagc_pt)})) - - from dataclasses import dataclass - from pytato.transform import TopoSortMapper - from pytato.partition import (find_partition, - execute_partition, generate_code_for_partition, - PartitionInducedCycleError) - - @dataclass(frozen=True, eq=True) - class MyPartitionId(): - num: int - - def get_partition_id(topo_list, expr) -> MyPartitionId: - return topo_list.index(expr) // 3 - - tm = TopoSortMapper() - tm(dict_named_arys) - - from functools import partial - part_func = partial(get_partition_id, tm.topological_order) - - try: - partition = find_partition(dict_named_arys, part_func) - except PartitionInducedCycleError: - print("CYCLE!") - # FIXME *shrug* nothing preventing that currently - ncycles += 1 - continue - - from pytato.visualization import get_dot_graph_from_partition - get_dot_graph_from_partition(partition) - - # Execute the partitioned code - prg_per_part = generate_code_for_partition(partition) - - context = execute_partition(partition, prg_per_part, queue) - - pt_part_res, = [context[k] for k in dict_named_arys] - - np.testing.assert_allclose(pt_part_res, ref_result) - - # Assert that at least 2/3 of our tests did not get skipped because of cycles - assert ncycles < ntests // 3 - - def test_assume_non_negative_indirect_address(ctx_factory): from numpy.random import default_rng from pytato.scalar_expr import WalkMapper diff --git a/test/test_distributed.py b/test/test_distributed.py index 232c408..97a7be5 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -21,6 +21,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import pytest +from pytools.graph import CycleError from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) import pyopencl as cl @@ -97,7 +99,7 @@ def _do_test_distributed_execution_basic(ctx_factory): # Find the partition outputs = pt.make_dict_of_named_arrays({"out": y}) - distributed_parts = pt.find_distributed_partition(outputs) + distributed_parts = pt.find_distributed_partition(comm, outputs) prg_per_partition = pt.generate_code_for_partition(distributed_parts) # Execute the distributed partition @@ -175,7 +177,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): {"result": make_random_dag(rdagc_comm)}) x_comm = pt.transform.materialize_with_mpms(pt_dag) - distributed_partition = pt.find_distributed_partition(x_comm) + distributed_partition = pt.find_distributed_partition(comm, x_comm) pt.verify_distributed_partition(comm, distributed_partition) # Transform symbolic tags into numeric ones for MPI @@ -184,11 +186,6 @@ def _do_test_distributed_execution_random_dag(ctx_factory): distributed_partition, base_tag=comm_tag) - # Regression check for https://github.com/inducer/pytato/issues/307 - from pytato.distributed.partition import PartIDTag - for ary in distributed_partition.var_name_to_result.values(): - assert not ary.tags_of_type(PartIDTag) - prg_per_partition = pt.generate_code_for_partition(distributed_partition) context = pt.execute_distributed_partition( @@ -242,7 +239,7 @@ def _test_dag_with_no_comm_nodes_inner(ctx_factory): # }}} - parts = pt.find_distributed_partition(dag) + parts = pt.find_distributed_partition(comm, dag) assert len(parts.parts) == 1 prg_per_partition = pt.generate_code_for_partition(parts) out_dict = pt.execute_distributed_partition( @@ -260,60 +257,48 @@ def test_dag_with_no_comm_nodes(): # {{{ test deterministic partitioning -def _check_deterministic_partition(dag, ref_partition, - iproc, results): - partition = pt.find_distributed_partition(dag) - are_equal = int(partition == ref_partition) - print(iproc, are_equal) - results[iproc] = are_equal +def _gather_random_dist_partitions(ctx_factory): + import mpi4py.MPI as MPI + comm = MPI.COMM_WORLD -def test_deterministic_partitioning(): - import multiprocessing as mp - import os + seed = int(os.environ["PYTATO_DAG_SEED"]) from testlib import get_random_pt_dag_with_send_recv_nodes + dag = get_random_pt_dag_with_send_recv_nodes( + seed, rank=comm.rank, size=comm.size, + convert_dws_to_placeholders=True) - original_hash_seed = os.environ.pop("PYTHONHASHSEED", None) - - nprocs = 4 - - mp_ctx = mp.get_context("spawn") - - ntests = 10 - for i in range(ntests): - seed = 120 + i - results = mp_ctx.Array("i", (0, ) * nprocs) - print(f"Step {i} {seed}") - - ref_dag = get_random_pt_dag_with_send_recv_nodes( - seed, rank=0, size=7, - convert_dws_to_placeholders=True) + my_partition = pt.find_distributed_partition(comm, dag) - ref_partition = pt.find_distributed_partition(ref_dag) + all_partitions = comm.gather(my_partition) - # {{{ spawn nprocs-processes and verify they all compare equally + from pickle import dump + if comm.rank == 0: + with open(os.environ["PYTATO_PARTITIONS_DUMP_FN"], "wb") as outf: + dump(all_partitions, outf) - procs = [mp_ctx.Process(target=_check_deterministic_partition, - args=(ref_dag, - ref_partition, - iproc, results)) - for iproc in range(nprocs)] - for iproc, proc in enumerate(procs): - # See - # https://replit.com/@KaushikKulkarn1/spawningprocswithhashseedv2?v=1#main.py - os.environ["PYTHONHASHSEED"] = str(iproc) - proc.start() +@pytest.mark.parametrize("seed", list(range(10))) +def test_deterministic_partitioning(seed): + import os + from pickle import load + from pytools import is_single_valued - for proc in procs: - proc.join() + partitions_across_seeds = [] + partitions_dump_fn = f"tmp-partitions-{os.getpid()}.pkl" - if original_hash_seed is not None: - os.environ["PYTHONHASHSEED"] = original_hash_seed + for hashseed in [234, 241, 9222, 5]: + run_test_with_mpi(2, _gather_random_dist_partitions, extra_env_vars={ + "PYTATO_DAG_SEED": str(seed), + "PYTHONHASHSEED": str(hashseed), + "PYTATO_PARTITIONS_DUMP_FN": partitions_dump_fn, + }) - assert set(results[:]) == {1} + with open(partitions_dump_fn, "rb") as inf: + partitions_across_seeds.append(load(inf)) + os.unlink(partitions_dump_fn) - # }}} + assert is_single_valued(partitions_across_seeds) # }}} @@ -330,7 +315,6 @@ def _do_verify_distributed_partition(ctx_factory): import pytest from pytato.distributed.verify import (DuplicateSendError, DuplicateRecvError, MissingSendError, MissingRecvError) - from pytato.partition import PartitionInducedCycleError rank = comm.Get_rank() size = comm.Get_size() @@ -341,16 +325,13 @@ def _do_verify_distributed_partition(ctx_factory): src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int) outputs = pt.make_dict_of_named_arrays({"out": y}) - distributed_parts = pt.find_distributed_partition(outputs) - - if rank == 0: - with pytest.raises(MissingSendError): - pt.verify_distributed_partition(comm, distributed_parts) - else: - pt.verify_distributed_partition(comm, distributed_parts) + with pytest.raises(MissingSendError): + pt.find_distributed_partition(comm, outputs) # }}} + comm.barrier() + # {{{ test unmatched send x = pt.make_placeholder("x", (4, 4), int) @@ -358,36 +339,30 @@ def _do_verify_distributed_partition(ctx_factory): dest_rank=(rank-1) % size, comm_tag=42, stapled_to=x) outputs = pt.make_dict_of_named_arrays({"out": send}) - distributed_parts = pt.find_distributed_partition(outputs) - - if rank == 0: - with pytest.raises(MissingRecvError): - pt.verify_distributed_partition(comm, distributed_parts) - else: - pt.verify_distributed_partition(comm, distributed_parts) + with pytest.raises(MissingRecvError): + pt.find_distributed_partition(comm, outputs) # }}} + comm.barrier() + # {{{ test duplicate recv recv2 = pt.make_distributed_recv( - src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int) + src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=float) send = pt.staple_distributed_send(recv2, dest_rank=(rank-1) % size, comm_tag=42, stapled_to=pt.make_distributed_recv( src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int)) outputs = pt.make_dict_of_named_arrays({"out": x+send}) - distributed_parts = pt.find_distributed_partition(outputs) - - if rank == 0: - with pytest.raises(DuplicateRecvError): - pt.verify_distributed_partition(comm, distributed_parts) - else: - pt.verify_distributed_partition(comm, distributed_parts) + with pytest.raises(DuplicateRecvError): + pt.find_distributed_partition(comm, outputs) # }}} + comm.barrier() + # {{{ test duplicate send send = pt.staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=42, @@ -398,16 +373,13 @@ def _do_verify_distributed_partition(ctx_factory): dest_rank=(rank-1) % size, comm_tag=42, stapled_to=x) outputs = pt.make_dict_of_named_arrays({"out": send+send2}) - distributed_parts = pt.find_distributed_partition(outputs) - - if rank == 0: - with pytest.raises(DuplicateSendError): - pt.verify_distributed_partition(comm, distributed_parts) - else: - pt.verify_distributed_partition(comm, distributed_parts) + with pytest.raises(DuplicateSendError): + pt.find_distributed_partition(comm, outputs) # }}} + comm.barrier() + # {{{ test cycle recv = pt.make_distributed_recv( @@ -423,13 +395,10 @@ def _do_verify_distributed_partition(ctx_factory): stapled_to=recv) outputs = pt.make_dict_of_named_arrays({"out": send+send2}) - distributed_parts = pt.find_distributed_partition(outputs) - if rank == 0: - with pytest.raises(PartitionInducedCycleError): - pt.verify_distributed_partition(comm, distributed_parts) - else: - pt.verify_distributed_partition(comm, distributed_parts) + print(f"BEGIN {comm.rank}") + with pytest.raises(CycleError): + pt.find_distributed_partition(comm, outputs) # }}} -- GitLab