diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1ae71440139ae1876df0dec525531bb3eb6d1c3..f098a94321e95a327ae1a6486c6e5c66d0cbd5dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -79,6 +79,8 @@ jobs: build_py_project_in_conda_env run_examples + mpirun -n 2 python distributed.py + docs: name: Documentation runs-on: ubuntu-latest diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 7c34e94ae8a63f1374259a9f84203e1b4edb4785..6459d4c06fcff2b84c489e7f64395cffe86f309f 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -11,3 +11,4 @@ dependencies: - pyopencl - islpy - sphinx-autodoc-typehints +- mpi4py diff --git a/doc/dag.rst b/doc/dag.rst index b6e7e7d24c61fffd4c9f56c31f3046bb45f61b30..f4874f1d741e0c65e8d5dbd1e36abc5cb80d53f9 100644 --- a/doc/dag.rst +++ b/doc/dag.rst @@ -28,6 +28,11 @@ Partitioning Array Expression Graphs .. automodule:: pytato.partition +Support for Distributed-Memory/Message Passing +============================================== + +.. automodule:: pytato.distributed + Utilities and Diagnostics ========================= diff --git a/doc/design.rst b/doc/design.rst index 71984d4c734ed508af3dd975560593e5b0e2f5fd..f402a933fad48133d8265617a6f9f639f3c996a2 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -146,6 +146,9 @@ Reserved Identifiers transport across parts of a partitioned DAG (cf. :func:`~pytato.partition.find_partition`). + - ``_pt_dist``: Used to automatically generate identifiers for + placeholders at distributed communication boundaries. + - Identifiers used in index lambdas are also reserved. These include: - Identifiers matching the regular expression ``_[0-9]+``. They are used diff --git a/examples/distributed.py b/examples/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..1f12839f01f74862dd08ee0f33e28164f1fb16bb --- /dev/null +++ b/examples/distributed.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python + +from mpi4py import MPI # pylint: disable=import-error +comm = MPI.COMM_WORLD + +import pytato as pt +import pyopencl as cl +import numpy as np + +from pytato import (find_distributed_partition, generate_code_for_partition, + number_distributed_tags, + execute_distributed_partition, + staple_distributed_send, make_distributed_recv) + + +def main(): + rank = comm.Get_rank() + size = comm.Get_size() + rng = np.random.default_rng() + + x_in = rng.integers(100, size=(4, 4)) + x = pt.make_data_wrapper(x_in) + + mytag = (main, "x") + halo = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag, shape=(4, 4), dtype=int)) + + y = x+halo + + # Find the partition + outputs = pt.DictOfNamedArrays({"out": y}) + distributed_parts = find_distributed_partition(outputs) + distributed_parts, _ = number_distributed_tags( + comm, distributed_parts, base_tag=42) + prg_per_partition = generate_code_for_partition(distributed_parts) + + if 0: + from pytato.visualization import show_dot_graph + show_dot_graph(distributed_parts) + + # Sanity check + from pytato.visualization import get_dot_graph_from_partition + get_dot_graph_from_partition(distributed_parts) + + # Execute the distributed partition + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + context = execute_distributed_partition(distributed_parts, prg_per_partition, + queue, comm) + + final_res = context["out"].get(queue) + + ref_res = comm.bcast(final_res) + + np.testing.assert_allclose(ref_res, final_res) + + 
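+    # Editorial note: this comparison relies on the two-rank CI run
+    # ("mpirun -n 2 python distributed.py" above). With two ranks, rank 0
+    # computes x_0 + x_1 and rank 1 computes x_1 + x_0, so the result
+    # broadcast from rank 0 agrees with every rank's local result.
+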
if rank == 0: + print("Distributed test succeeded.") + + +if __name__ == "__main__": + main() diff --git a/pytato/__init__.py b/pytato/__init__.py index 97444b6d03849281fc94f31d0878cd3709c1b8be..6582dbe55ad8a49d681db4ec5600878c0d4be029 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -32,7 +32,6 @@ from pytato.array import ( AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, SizeParam, Axis, - make_dict_of_named_arrays, make_placeholder, make_size_param, make_data_wrapper, einsum, @@ -70,6 +69,15 @@ from pytato.visualization import (get_dot_graph, show_dot_graph, import pytato.analysis as analysis import pytato.tags as tags import pytato.transform as transform +from pytato.distributed import (make_distributed_send, make_distributed_recv, + DistributedRecv, DistributedSend, + DistributedSendRefHolder, + staple_distributed_send, + find_distributed_partition, + number_distributed_tags, + execute_distributed_partition) + +from pytato.partition import generate_code_for_partition __all__ = ( "Array", "AbstractResultWithNamedArrays", "DictOfNamedArrays", @@ -111,6 +119,15 @@ __all__ = ( "broadcast_to", + "make_distributed_recv", "make_distributed_send", "DistributedRecv", + "DistributedSend", "staple_distributed_send", "DistributedSendRefHolder", + + "find_distributed_partition", + "number_distributed_tags", + "execute_distributed_partition", + + "generate_code_for_partition", + # sub-modules "analysis", "tags", "transform", diff --git a/pytato/distributed.py b/pytato/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..6624d6bcc11d7e25f3c82425fd31cd70832515db --- /dev/null +++ b/pytato/distributed.py @@ -0,0 +1,675 @@ +from __future__ import annotations + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +from typing import (Any, Dict, Hashable, Tuple, Optional, Set, # noqa: F401 + List, FrozenSet, Callable, cast, Mapping) # Mapping required by sphinx + +from dataclasses import dataclass + +from pytools import UniqueNameGenerator +from pytools.tag import Taggable, TagsType +from pytato.array import (Array, _SuppliedShapeAndDtypeMixin, + DictOfNamedArrays, ShapeType, + Placeholder, make_placeholder, + _get_default_axes, AxesT) +from pytato.transform import ArrayOrNames, CopyMapper +from pytato.partition import GraphPart, GraphPartition, PartId, GraphPartitioner +from pytato.target import BoundProgram + +import numpy as np + +__doc__ = r""" +Distributed communication +------------------------- + +.. currentmodule:: pytato +.. autoclass:: DistributedSend +.. autoclass:: DistributedSendRefHolder +.. autoclass:: DistributedRecv + +.. autofunction:: make_distributed_send +.. autofunction:: staple_distributed_send +.. autofunction:: make_distributed_recv + +.. currentmodule:: pytato.distributed + +.. autoclass:: DistributedGraphPart +.. autoclass:: DistributedGraphPartition + +.. currentmodule:: pytato + +.. autofunction:: find_distributed_partition +.. autofunction:: number_distributed_tags +.. autofunction:: execute_distributed_partition + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: CommTagType + + A type representing a communication tag. + +.. class:: TagsType + + A :class:`frozenset` of :class:`pytools.tag.Tag`\ s. + +.. class:: ShapeType + + A type representing a shape. + +.. class:: AxesT + + A :class:`tuple` of :class:`Axis` objects. +""" + + +# {{{ Distributed node types + +CommTagType = Hashable + + +class DistributedSend(Taggable): + """Class representing a distributed send operation. + + .. attribute:: data + + The :class:`~pytato.Array` to be sent. + + .. attribute:: dest_rank + + An :class:`int`. The rank to which :attr:`data` is to be sent. + + .. attribute:: comm_tag + + A hashable, picklable object to serve as a 'tag' for the communication. + Only a :class:`DistributedRecv` with the same tag will be able to + receive the data being sent here. + """ + + def __init__(self, data: Array, dest_rank: int, comm_tag: CommTagType, + tags: TagsType = frozenset()) -> None: + super().__init__(tags=tags) + self.data = data + self.dest_rank = dest_rank + self.comm_tag = comm_tag + + def __hash__(self) -> int: + return ( + hash(self.__class__) + ^ hash(self.data) + ^ hash(self.dest_rank) + ^ hash(self.comm_tag) + ^ hash(self.tags) + ) + + def __eq__(self, other: Any) -> bool: + return ( + self.__class__ is other.__class__ + and self.data == other.data + and self.dest_rank == other.dest_rank + and self.comm_tag == other.comm_tag + and self.tags == other.tags) + + def copy(self, **kwargs: Any) -> DistributedSend: + data: Optional[Array] = kwargs.get("data") + dest_rank: Optional[int] = kwargs.get("dest_rank") + comm_tag: Optional[CommTagType] = kwargs.get("comm_tag") + tags: Optional[TagsType] = kwargs.get("tags") + return type(self)( + data=data or self.data, + dest_rank=dest_rank if dest_rank is not None else self.dest_rank, + comm_tag=comm_tag if comm_tag is not None else self.comm_tag, + tags=tags if tags is not None else self.tags) + + +class DistributedSendRefHolder(Array): + """A node acting as an identity on :attr:`passthrough_data` while also holding + a reference to a :class:`DistributedSend` in :attr:`send`. 
+    Since :mod:`pytato` represents data flow, and since no data flows 'out'
+    of a :class:`DistributedSend`, no node in all of :mod:`pytato` has
+    a good reason to hold a reference to a send node, since there is
+    no useful result of a send (at least not one of an
+    :class:`~pytato.Array` type).
+
+    This is where this node type comes in. Its value is the same as that of
+    :attr:`passthrough_data`, *and* it holds a reference to the send node.
+
+    .. note::
+
+        This all seems a wee bit inelegant, but nobody who has written
+        or reviewed this code so far had a better idea. If you do, please
+        speak up!
+
+    .. attribute:: send
+
+        The :class:`DistributedSend` to which a reference is to be held.
+
+    .. attribute:: passthrough_data
+
+        A :class:`~pytato.Array`. The value of this node.
+    """
+
+    _mapper_method = "map_distributed_send_ref_holder"
+    _fields = Array._fields + ("passthrough_data", "send",)
+
+    def __init__(self, send: DistributedSend, passthrough_data: Array,
+                 tags: TagsType = frozenset()) -> None:
+        super().__init__(axes=passthrough_data.axes, tags=tags)
+        self.send = send
+        self.passthrough_data = passthrough_data
+
+    @property
+    def shape(self) -> ShapeType:
+        return self.passthrough_data.shape
+
+    @property
+    def dtype(self) -> np.dtype[Any]:
+        return self.passthrough_data.dtype
+
+
+class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array):
+    """Class representing a distributed receive operation.
+
+    .. attribute:: src_rank
+
+        An :class:`int`. The rank from which an array is to be received.
+
+    .. attribute:: comm_tag
+
+        A hashable, picklable object to serve as a 'tag' for the communication.
+        Only the data from a :class:`DistributedSend` with the same tag will
+        be received here.
+
+    .. attribute:: shape
+    .. attribute:: dtype
+    """
+
+    _fields = Array._fields + ("shape", "dtype", "src_rank", "comm_tag")
+    _mapper_method = "map_distributed_recv"
+
+    def __init__(self, src_rank: int, comm_tag: CommTagType,
+                 shape: ShapeType, dtype: Any,
+                 tags: Optional[TagsType] = frozenset(),
+                 axes: Optional[AxesT] = None) -> None:
+
+        if not axes:
+            axes = _get_default_axes(len(shape))
+        super().__init__(shape=shape, dtype=dtype, tags=tags,
+                         axes=axes)
+        self.src_rank = src_rank
+        self.comm_tag = comm_tag
+
+
+def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType,
+                          send_tags: TagsType = frozenset()) -> \
+        DistributedSend:
+    return DistributedSend(sent_data, dest_rank, comm_tag, send_tags)
+
+
+def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType,
+                            stapled_to: Array, *,
+                            send_tags: TagsType = frozenset(),
+                            ref_holder_tags: TagsType = frozenset()) -> \
+        DistributedSendRefHolder:
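+    """Create a :class:`DistributedSend` object for *sent_data* and wrap it
+    in a :class:`DistributedSendRefHolder` whose value is that of
+    *stapled_to*. (Editorial docstring, added for exposition; the attribute
+    documentation of the two classes above describes the parameters.)
+    """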
+    return DistributedSendRefHolder(
+            DistributedSend(sent_data, dest_rank, comm_tag, send_tags),
+            stapled_to, tags=ref_holder_tags)
+
+
+def make_distributed_recv(src_rank: int, comm_tag: CommTagType,
+                          shape: ShapeType, dtype: Any,
+                          tags: TagsType = frozenset()) \
+                          -> DistributedRecv:
+    dtype = np.dtype(dtype)
+    return DistributedRecv(src_rank, comm_tag, shape, dtype, tags)
+
+# }}}
+
+
+# {{{ distributed info collection
+
+@dataclass(frozen=True)
+class DistributedGraphPart(GraphPart):
+    """For one graph partition, record send/receive information for input/
+    output names.
+
+    .. attribute:: input_name_to_recv_node
+    .. attribute:: output_name_to_send_node
+    .. attribute:: distributed_sends
+    """
+    input_name_to_recv_node: Dict[str, DistributedRecv]
+    output_name_to_send_node: Dict[str, DistributedSend]
+    distributed_sends: List[DistributedSend]
+
+
+@dataclass(frozen=True)
+class DistributedGraphPartition(GraphPartition):
+    """Store information about distributed graph partitions. This
+    has the same attributes as :class:`~pytato.partition.GraphPartition`,
+    however :attr:`~pytato.partition.GraphPartition.parts` now maps to
+    instances of :class:`DistributedGraphPart`.
+    """
+    parts: Dict[PartId, DistributedGraphPart]
+
+
+class _DistributedCommReplacer(CopyMapper):
+    """Mapper to process a DAG for realization of :class:`DistributedSend`
+    and :class:`DistributedRecv` outside of normal code generation.
+
+    - Replaces :class:`DistributedRecv` with :class:`~pytato.Placeholder`
+      so that received data can be externally supplied, making a note
+      in :attr:`input_name_to_recv_node`.
+
+    - Eliminates :class:`DistributedSendRefHolder` and
+      :class:`DistributedSend` from the DAG, making a note of the data
+      to be sent in :attr:`output_name_to_send_node`.
+    """
+
+    def __init__(self, dist_name_generator: UniqueNameGenerator) -> None:
+        super().__init__()
+
+        self.name_generator = dist_name_generator
+
+        self.input_name_to_recv_node: Dict[str, DistributedRecv] = {}
+        self.output_name_to_send_node: Dict[str, DistributedSend] = {}
+
+    def map_distributed_recv(self, expr: DistributedRecv) -> Placeholder:
+        # no children, no need to recurse
+
+        new_name = self.name_generator()
+        self.input_name_to_recv_node[new_name] = expr
+        return make_placeholder(new_name, self.rec_idx_or_size_tuple(expr.shape),
+                                expr.dtype, tags=expr.tags)
+
+    def map_distributed_send_ref_holder(
+            self, expr: DistributedSendRefHolder) -> Array:
+        raise ValueError("DistributedSendRefHolder should not occur in "
+                         "partitioned graphs")
+
+    def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
+        new_send = DistributedSend(
+                data=self.rec(expr.data),
+                dest_rank=expr.dest_rank,
+                comm_tag=expr.comm_tag,
+                tags=expr.tags)
+
+        new_name = self.name_generator()
+        self.output_name_to_send_node[new_name] = new_send
+
+        return new_send
+
+
+def _gather_distributed_comm_info(partition: GraphPartition,
+        pid_to_distributed_sends: Dict[PartId, List[DistributedSend]]) -> \
+            DistributedGraphPartition:
+    var_name_to_result = {}
+    parts: Dict[PartId, DistributedGraphPart] = {}
+
+    dist_name_generator = UniqueNameGenerator(forced_prefix="_pt_dist_")
+
+    for part in partition.parts.values():
+        comm_replacer = _DistributedCommReplacer(dist_name_generator)
+        part_results = {
+            var_name: comm_replacer(partition.var_name_to_result[var_name])
+            for var_name in part.output_names}
+
+        dist_sends = [
+            comm_replacer.map_distributed_send(send)
+            for send in pid_to_distributed_sends.get(part.pid, [])]
+
+        part_results.update({
+            name: send_node.data
+            for name, send_node in
+            comm_replacer.output_name_to_send_node.items()})
+
+        parts[part.pid] = DistributedGraphPart(
+                pid=part.pid,
+                needed_pids=part.needed_pids,
+                user_input_names=part.user_input_names,
+                partition_input_names=(part.partition_input_names
+                    | frozenset(comm_replacer.input_name_to_recv_node)),
+                output_names=(part.output_names
+                    | frozenset(comm_replacer.output_name_to_send_node)),
+                distributed_sends=dist_sends,
+
+                input_name_to_recv_node=comm_replacer.input_name_to_recv_node,
+                output_name_to_send_node=comm_replacer.output_name_to_send_node)
+
+        for name, val in part_results.items():
+            assert name not in var_name_to_result
+
var_name_to_result[name] = val + + result = DistributedGraphPartition( + parts=parts, + var_name_to_result=var_name_to_result, + toposorted_part_ids=partition.toposorted_part_ids) + + if __debug__: + # Check disjointness again since we replaced a few nodes. + from pytato.partition import _check_partition_disjointness + _check_partition_disjointness(result) + + return result + +# }}} + + +# {{{ find distributed partition + +@dataclass(frozen=True, eq=True) +class DistributedPartitionId(): + fed_sends: object + feeding_recvs: object + + +class _DistributedGraphPartitioner(GraphPartitioner): + + def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: + super().__init__(get_part_id) + self.pid_to_dist_sends: Dict[PartId, List[DistributedSend]] = {} + + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> Any: + send_part_id = self.get_part_id(expr.send.data) + + from pytato.distributed import DistributedSend + self.pid_to_dist_sends.setdefault(send_part_id, []).append( + DistributedSend( + data=self.rec(expr.send.data), + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags)) + + return self.rec(expr.passthrough_data) + + def make_partition(self, outputs: DictOfNamedArrays) \ + -> DistributedGraphPartition: + + partition = super().make_partition(outputs) + return _gather_distributed_comm_info(partition, self.pid_to_dist_sends) + + +def find_distributed_partition( + outputs: DictOfNamedArrays) -> DistributedGraphPartition: + """Finds a partitioning in a distributed context.""" + + from pytato.transform import (UsersCollector, TopoSortMapper, + reverse_graph, tag_user_nodes) + + gdm = UsersCollector() + gdm(outputs) + + graph = gdm.node_to_users + + # type-ignore-reason: + # 'graph' also includes DistributedSend nodes, which are not Arrays + rev_graph = reverse_graph(graph) # type: ignore[arg-type] + + # FIXME: Inefficient... 
too many traversals + node_to_feeding_recvs: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} + for node in graph: + node_to_feeding_recvs.setdefault(node, set()) + if isinstance(node, DistributedRecv): + tag_user_nodes(graph, tag=node, # type: ignore[arg-type] + starting_point=node, + node_to_tags=node_to_feeding_recvs) + + node_to_fed_sends: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} + for node in rev_graph: + node_to_fed_sends.setdefault(node, set()) + if isinstance(node, DistributedSend): + tag_user_nodes(rev_graph, tag=node, starting_point=node, + node_to_tags=node_to_fed_sends) + + def get_part_id(expr: ArrayOrNames) -> DistributedPartitionId: + return DistributedPartitionId(frozenset(node_to_fed_sends[expr]), + frozenset(node_to_feeding_recvs[expr])) + + # {{{ Sanity checks + + if __debug__: + for node, _ in node_to_feeding_recvs.items(): + for n in node_to_feeding_recvs[node]: + assert(isinstance(n, DistributedRecv)) + + for node, _ in node_to_fed_sends.items(): + for n in node_to_fed_sends[node]: + assert(isinstance(n, DistributedSend)) + + tm = TopoSortMapper() + tm(outputs) + + for node in tm.topological_order: + get_part_id(node) + + # }}} + + from pytato.partition import find_partition + return cast(DistributedGraphPartition, + find_partition(outputs, get_part_id, _DistributedGraphPartitioner)) + +# }}} + + +# {{{ construct tag numbering + +def number_distributed_tags( + mpi_communicator: Any, + partition: DistributedGraphPartition, + base_tag: int) -> Tuple[DistributedGraphPartition, int]: + """Return a new :class:`~pytato.distributed.DistributedGraphPartition` + in which symbolic tags are replaced by unique integer tags, created by + counting upward from *base_tag*. + + :returns: a tuple ``(partition, next_tag)``, where *partition* is the new + :class:`~pytato.distributed.DistributedGraphPartition` with numerical + tags, and *next_tag* is the lowest integer tag above *base_tag* that + was not used. + + .. note:: + + This is a potentially heavyweight MPI-collective operation on + *mpi_communicator*. + """ + tags = frozenset({ + recv.comm_tag + for part in partition.parts.values() + for name, recv in part.input_name_to_recv_node.items() + } | { + send.comm_tag + for part in partition.parts.values() + for name, send in part.output_name_to_send_node.items()}) + + from mpi4py import MPI + + def set_union( + set_a: FrozenSet[Any], set_b: FrozenSet[Any], + mpi_data_type: MPI.Datatype) -> FrozenSet[str]: + assert mpi_data_type is None + assert isinstance(set_a, frozenset) + assert isinstance(set_b, frozenset) + + return set_a | set_b + + root_rank = 0 + + set_union_mpi_op = MPI.Op.Create( + # type ignore reason: mpi4py misdeclares op functions as returning + # None. 
+            set_union,  # type: ignore[arg-type]
+            commute=True)
+    try:
+        all_tags = mpi_communicator.reduce(
+                tags, set_union_mpi_op, root=root_rank)
+    finally:
+        set_union_mpi_op.Free()
+
+    if mpi_communicator.rank == root_rank:
+        sym_tag_to_int_tag = {}
+        next_tag = base_tag
+        for sym_tag in all_tags:
+            sym_tag_to_int_tag[sym_tag] = next_tag
+            next_tag += 1
+
+        mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank)
+    else:
+        sym_tag_to_int_tag, next_tag = mpi_communicator.bcast(None, root=root_rank)
+
+    from dataclasses import replace
+    return DistributedGraphPartition(
+            parts={
+                pid: replace(part,
+                    input_name_to_recv_node={
+                        name: recv.copy(comm_tag=sym_tag_to_int_tag[recv.comm_tag])
+                        for name, recv in part.input_name_to_recv_node.items()},
+                    output_name_to_send_node={
+                        name: send.copy(comm_tag=sym_tag_to_int_tag[send.comm_tag])
+                        for name, send in part.output_name_to_send_node.items()},
+                    )
+                for pid, part in partition.parts.items()
+                },
+            var_name_to_result=partition.var_name_to_result,
+            toposorted_part_ids=partition.toposorted_part_ids), next_tag
+
+# }}}
+
+
+# {{{ distributed execute
+
+def _post_receive(mpi_communicator: Any,
+                 recv: DistributedRecv) -> Tuple[Any, np.ndarray[Any, Any]]:
+    # FIXME: recv.shape might be parametric, evaluate
+    buf = np.empty(recv.shape, dtype=recv.dtype)
+
+    return mpi_communicator.Irecv(
+            buf=buf, source=recv.src_rank, tag=recv.comm_tag), buf
+
+
+def _mpi_send(mpi_communicator: Any, send_node: DistributedSend,
+             data: np.ndarray[Any, Any]) -> Any:
+    # Must use a non-blocking send, as a blocking send may wait for a
+    # corresponding receive to be posted (but if sending to self, this may
+    # only occur later).
+    return mpi_communicator.Isend(
+            data, dest=send_node.dest_rank, tag=send_node.comm_tag)
+
+
+def execute_distributed_partition(
+        partition: DistributedGraphPartition, prg_per_partition:
+        Dict[Hashable, BoundProgram],
+        queue: Any, mpi_communicator: Any,
+        input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+
+    if input_args is None:
+        input_args = {}
+
+    from mpi4py import MPI
+
+    if len(partition.parts) != 1:
+        recv_names_tup, recv_requests_tup, recv_buffers_tup = zip(*[
+            (name,) + _post_receive(mpi_communicator, recv)
+            for part in partition.parts.values()
+            for name, recv in part.input_name_to_recv_node.items()])
+        recv_names = list(recv_names_tup)
+        recv_requests = list(recv_requests_tup)
+        recv_buffers = list(recv_buffers_tup)
+        del recv_names_tup
+        del recv_requests_tup
+        del recv_buffers_tup
+    else:
+        # Only a single partition, no recv requests exist
+        recv_names = []
+        recv_requests = []
+        recv_buffers = []
+
+    context: Dict[str, Any] = input_args.copy()
+
+    pids_to_execute = set(partition.parts)
+    pids_executed = set()
+    recv_names_completed = set()
+    send_requests = []
+
+    def exec_ready_part(part: DistributedGraphPart) -> None:
+        inputs = {k: context[k] for k in part.all_input_names()}
+
+        _evt, result_dict = prg_per_partition[part.pid](queue, **inputs)
+
+        context.update(result_dict)
+
+        for name, send_node in part.output_name_to_send_node.items():
+            # FIXME: pytato shouldn't depend on pyopencl
+            if isinstance(context[name], np.ndarray):
+                data = context[name]
+            else:
+                data = context[name].get(queue)
+            send_requests.append(_mpi_send(mpi_communicator, send_node, data))
+
+        pids_executed.add(part.pid)
+        pids_to_execute.remove(part.pid)
+
+    def wait_for_some_recvs() -> None:
+        complete_recv_indices = MPI.Request.Waitsome(recv_requests)
+
+        # Waitsome is allowed to return None
+        if not complete_recv_indices:
+
complete_recv_indices = [] + + # reverse to preserve indices + for idx in sorted(complete_recv_indices, reverse=True): + name = recv_names.pop(idx) + recv_requests.pop(idx) + buf = recv_buffers.pop(idx) + + # FIXME: pytato shouldn't depend on pyopencl + import pyopencl as cl + context[name] = cl.array.to_device(queue, buf) + recv_names_completed.add(name) + + # FIXME: This keeps all variables alive that are used to get data into + # and out of partitions. Probably not what we want long-term. + + # {{{ main loop + + while pids_to_execute: + ready_pids = {pid + for pid in pids_to_execute + # FIXME: Only O(n**2) altogether. Nobody is going to notice, right? + if partition.parts[pid].needed_pids <= pids_executed + and (set(partition.parts[pid].input_name_to_recv_node) + <= recv_names_completed)} + for pid in ready_pids: + exec_ready_part(partition.parts[pid]) + + if not ready_pids: + wait_for_some_recvs() + + # }}} + + for send_req in send_requests: + send_req.Wait() + + return context + +# }}} + +# vim: foldmethod=marker diff --git a/pytato/equality.py b/pytato/equality.py index 0bfdb6242a704c51ae9056a087ac86854b045f38..984d1f7124ede7dfdf22a2a6fd443f21f9675073 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -34,6 +34,7 @@ from pytato.array import (AdvancedIndexInContiguousAxes, if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult + from pytato.distributed import DistributedRecv, DistributedSendRefHolder __doc__ = """ .. autoclass:: EqualityComparer @@ -245,6 +246,26 @@ class EqualityComparer: and all(self.rec(expr1._data[name], expr2._data[name]) for name in expr1._data)) + def map_distributed_send_ref_holder( + self, expr1: DistributedSendRefHolder, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.send.data, expr2.send.data) + and self.rec(expr1.passthrough_data, expr2.passthrough_data) + and expr1.send.dest_rank == expr2.send.dest_rank + and expr1.send.comm_tag == expr2.send.comm_tag + and expr1.send.tags == expr2.send.tags + and expr1.tags == expr2.tags + ) + + def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.src_rank == expr2.src_rank + and expr1.comm_tag == expr2.comm_tag + and expr1.shape == expr2.shape + and expr1.dtype == expr2.dtype + and expr1.tags == expr2.tags + ) + # }}} # vim: fdm=marker diff --git a/pytato/partition.py b/pytato/partition.py index 72012a533863a84803e84d66b224ef8157378b54..223dc7ef53ed81a6506d435787f528900dffc220 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -25,10 +25,11 @@ THE SOFTWARE. """ from typing import (Any, Callable, Dict, Union, Set, List, Hashable, Tuple, TypeVar, - FrozenSet, Mapping) + FrozenSet, Mapping, Optional, Type) from dataclasses import dataclass +from pytools import memoize_method from pytato.transform import EdgeCachedMapper, CachedWalkMapper from pytato.array import ( Array, AbstractResultWithNamedArrays, Placeholder, @@ -40,10 +41,19 @@ from pytato.target import BoundProgram __doc__ = """ .. autoclass:: GraphPart .. autoclass:: GraphPartition +.. autoclass:: GraphPartitioner .. autoexception:: PartitionInducedCycleError .. autofunction:: find_partition .. autofunction:: execute_partition + + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: T + + A type variable for :class:`~pytato.array.AbstractResultWithNamedArrays`. 
""" @@ -54,10 +64,14 @@ PartId = Hashable # {{{ graph partitioner -class _GraphPartitioner(EdgeCachedMapper): +class GraphPartitioner(EdgeCachedMapper): """Given a function *get_part_id*, produces subgraphs representing the computation. Users should not use this class directly, but use :meth:`find_partition` instead. + + .. automethod:: __init__ + .. automethod:: __call__ + .. automethod:: make_partition """ def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: @@ -87,6 +101,8 @@ class _GraphPartitioner(EdgeCachedMapper): # miss any of them. self.seen_part_ids: Set[PartId] = set() + self.pid_to_user_input_names: Dict[PartId, Set[str]] = {} + def get_part_id(self, expr: ArrayOrNames) -> PartId: part_id = self._get_part_id(expr) self.seen_part_ids.add(part_id) @@ -146,6 +162,62 @@ class _GraphPartitioner(EdgeCachedMapper): return super().__call__(expr, *args, **kwargs) + def make_partition(self, outputs: DictOfNamedArrays) -> GraphPartition: + rewritten_outputs = { + name: self(expr) for name, expr in outputs._data.items()} + + pid_to_output_names: Dict[PartId, Set[str]] = { + pid: set() for pid in self.seen_part_ids} + pid_to_input_names: Dict[PartId, Set[str]] = { + pid: set() for pid in self.seen_part_ids} + + var_name_to_result = self.var_name_to_result.copy() + + for out_name, rewritten_output in rewritten_outputs.items(): + out_part_id = self._get_part_id(outputs._data[out_name]) + pid_to_output_names.setdefault(out_part_id, set()).add(out_name) + var_name_to_result[out_name] = rewritten_output + + # Mapping of nodes to their successors; used to compute the topological order + pid_to_needing_pids: Dict[PartId, Set[PartId]] = { + pid: set() for pid in self.seen_part_ids} + pid_to_needed_pids: Dict[PartId, Set[PartId]] = { + pid: set() for pid in self.seen_part_ids} + + for (pid_target, pid_dependency), var_names in \ + self.part_pair_to_edges.items(): + pid_to_needing_pids[pid_dependency].add(pid_target) + pid_to_needed_pids[pid_target].add(pid_dependency) + + for var_name in var_names: + pid_to_output_names[pid_dependency].add(var_name) + pid_to_input_names[pid_target].add(var_name) + + from pytools.graph import compute_topological_order, CycleError + try: + toposorted_part_ids = compute_topological_order(pid_to_needing_pids) + except CycleError: + raise PartitionInducedCycleError + + return GraphPartition( + parts={ + pid: GraphPart( + pid=pid, + needed_pids=frozenset(pid_to_needed_pids[pid]), + user_input_names=frozenset( + self.pid_to_user_input_names.get(pid, set())), + partition_input_names=frozenset(pid_to_input_names[pid]), + output_names=frozenset(pid_to_output_names[pid]), + ) + for pid in self.seen_part_ids}, + var_name_to_result=var_name_to_result, + toposorted_part_ids=toposorted_part_ids) + + def map_placeholder(self, expr: Placeholder, *args: Any) -> Any: + pid = self.get_part_id(expr) + self.pid_to_user_input_names.setdefault(pid, set()).add(expr.name) + return super().map_placeholder(expr) + # }}} @@ -163,19 +235,38 @@ class GraphPart: The IDs of parts that are required to be evaluated before this part can be evaluated. - .. attribute:: input_names + .. attribute:: user_input_names + + A :class:`dict` mapping names to :class:`~pytato.array.Placeholder` + instances that represent input to the computational graph, i.e. were + *not* introduced by partitioning. - Names of placeholders the part requires as input. + .. attribute:: partition_input_names + + Names of placeholders the part requires as input from other parts + in the partition. .. 
 # }}}
 
 
@@ -163,19 +235,38 @@ class GraphPart:
         The IDs of parts that are required to be evaluated before this
         part can be evaluated.
 
-    .. attribute:: input_names
+    .. attribute:: user_input_names
+
+        A :class:`frozenset` of names of :class:`~pytato.array.Placeholder`
+        instances that represent input to the computational graph, i.e. ones
+        that were *not* introduced by partitioning.
 
-        Names of placeholders the part requires as input.
+    .. attribute:: partition_input_names
+
+        Names of placeholders the part requires as input from other parts
+        in the partition.
 
     .. attribute:: output_names
 
         Names of placeholders this part provides as output.
+
+    .. attribute:: distributed_sends
+
+        List of :class:`~pytato.distributed.DistributedSend` instances whose
+        data are in this part.
+
+    .. automethod:: all_input_names
     """
     pid: PartId
     needed_pids: FrozenSet[PartId]
-    input_names: FrozenSet[str]
+    user_input_names: FrozenSet[str]
+    partition_input_names: FrozenSet[str]
     output_names: FrozenSet[str]
 
+    @memoize_method
+    def all_input_names(self) -> FrozenSet[str]:
+        return self.user_input_names | self.partition_input_names
+
 
 @dataclass(frozen=True)
 class GraphPartition:
@@ -214,10 +305,11 @@ class PartitionInducedCycleError(Exception):
     """
 
 
-# {{{ find_partitions
+# {{{ find_partition
 
 def find_partition(outputs: DictOfNamedArrays,
-        part_func: Callable[[ArrayOrNames], PartId]) ->\
+        part_func: Callable[[ArrayOrNames], PartId],
+        partitioner_class: Type[GraphPartitioner] = GraphPartitioner) ->\
         GraphPartition:
     """Partitions the *expr* according to *part_func* and generates code
     for each partition. Raises :exc:`PartitionInducedCycleError` if the
     partitioning
@@ -235,58 +327,15 @@ def find_partition(outputs: DictOfNamedArrays,
     where ``A`` and ``C`` are in partition 1, and ``B`` is in partition 2.
 
-    :param expr: The expression to partition.
+    :param outputs: The outputs to partition.
     :param part_func: A callable that returns an instance of
         :class:`Hashable` for a node.
+    :param partitioner_class: A :class:`GraphPartitioner` subclass to
+        guide the partitioning.
    :returns: An instance of :class:`GraphPartition` that contains the
        partition.
    """
-    gp = _GraphPartitioner(part_func)
-    rewritten_outputs = {name: gp(expr) for name, expr in outputs._data.items()}
-
-    pid_to_output_names: Dict[PartId, Set[str]] = {
-        pid: set() for pid in gp.seen_part_ids}
-    pid_to_input_names: Dict[PartId, Set[str]] = {
-        pid: set() for pid in gp.seen_part_ids}
-
-    var_name_to_result = gp.var_name_to_result.copy()
-
-    for out_name, rewritten_output in rewritten_outputs.items():
-        out_part_id = part_func(outputs._data[out_name])
-        pid_to_output_names.setdefault(out_part_id, set()).add(out_name)
-        var_name_to_result[out_name] = rewritten_output
-
-    # Mapping of nodes to their successors; used to compute the topological order
-    pid_to_needing_pids: Dict[PartId, Set[PartId]] = {
-        pid: set() for pid in gp.seen_part_ids}
-    pid_to_needed_pids: Dict[PartId, Set[PartId]] = {
-        pid: set() for pid in gp.seen_part_ids}
-
-    for (pid_target, pid_dependency), var_names in \
-            gp.part_pair_to_edges.items():
-        pid_to_needing_pids[pid_dependency].add(pid_target)
-        pid_to_needed_pids[pid_target].add(pid_dependency)
-
-        for var_name in var_names:
-            pid_to_output_names[pid_dependency].add(var_name)
-            pid_to_input_names[pid_target].add(var_name)
-
-    from pytools.graph import compute_topological_order, CycleError
-    try:
-        toposorted_part_ids = compute_topological_order(pid_to_needing_pids)
-    except CycleError:
-        raise PartitionInducedCycleError
-
-    result = GraphPartition(
-            parts={
-                pid: GraphPart(
-                    pid=pid,
-                    needed_pids=frozenset(pid_to_needed_pids[pid]),
-                    input_names=frozenset(pid_to_input_names[pid]),
-                    output_names=frozenset(pid_to_output_names[pid]))
-                for pid in gp.seen_part_ids},
-            var_name_to_result=var_name_to_result,
-            toposorted_part_ids=toposorted_part_ids)
+    result = partitioner_class(part_func).make_partition(outputs)
 
     if __debug__:
         _check_partition_disjointness(result)
@@ -320,12 +369,14 @@ def _check_partition_disjointness(partition: GraphPartition) -> None:
                 # Placeholders represent values computed in one
partition # and used in one or more other ones. As a result, the # same placeholder may occur in more than one partition. - assert ( - isinstance(my_node, Placeholder) - or my_node not in other_node_set), ( - "partitions not disjoint: " - f"{my_node.__class__.__name__} (id={id(my_node)}) " - f"in both '{part.pid}' and '{other_part_id}'") + if not(isinstance(my_node, Placeholder) + or my_node not in other_node_set): + raise RuntimeError( + "Partitions not disjoint: " + f"{my_node.__class__.__name__} (id={hex(id(my_node))}) " + f"in both '{part.pid}' and '{other_part_id}'" + f"{part.output_names=} " + f"{partition.parts[other_part_id].output_names=} ") part_id_to_nodes[part.pid] = mapper.seen_nodes @@ -339,16 +390,16 @@ def generate_code_for_partition(partition: GraphPartition) \ """Return a mapping of partition identifiers to their :class:`pytato.target.BoundProgram`.""" from pytato import generate_loopy - prg_per_partition = {} + part_id_to_prg = {} for part in partition.parts.values(): d = DictOfNamedArrays( {var_name: partition.var_name_to_result[var_name] for var_name in part.output_names }) - prg_per_partition[part.pid] = generate_loopy(d) + part_id_to_prg[part.pid] = generate_loopy(d) - return prg_per_partition + return part_id_to_prg # }}} @@ -356,7 +407,8 @@ def generate_code_for_partition(partition: GraphPartition) \ # {{{ execute_partitions def execute_partition(partition: GraphPartition, prg_per_partition: - Dict[PartId, BoundProgram], queue: Any) -> Dict[str, Any]: + Dict[PartId, BoundProgram], queue: Any, + input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """Executes a set of partitions on a :class:`pyopencl.CommandQueue`. :param parts: An instance of :class:`GraphPartition` representing the @@ -365,11 +417,15 @@ def execute_partition(partition: GraphPartition, prg_per_partition: code on. :returns: A dictionary of variable names mapped to their values. """ - context: Dict[str, Any] = {} + if input_args is None: + input_args = {} + + context: Dict[str, Any] = input_args.copy() + for pid in partition.toposorted_part_ids: part = partition.parts[pid] inputs = { - k: context[k] for k in part.input_names + k: context[k] for k in part.all_input_names() if k in context} _evt, result_dict = prg_per_partition[pid](queue=queue, **inputs) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index e3f76b01512a3bab80a0d744be7aa81c66a6b3b4..b2fd5bf03502eb5ddeef25ecebc4820b80b68358 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -126,6 +126,8 @@ class Reprifier(Mapper): map_non_contiguous_advanced_index = _map_generic_array map_reshape = _map_generic_array map_einsum = _map_generic_array + map_distributed_recv = _map_generic_array + map_distributed_send_ref_holder = _map_generic_array def map_data_wrapper(self, expr: DataWrapper, depth: int) -> str: if depth > self.truncation_depth: diff --git a/pytato/transform.py b/pytato/transform.py index 863e927a6bd09571b05c5573ec5264e36532775a..bb531d11046b4a2395098831b72beab95603172a 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -2,6 +2,8 @@ from __future__ import annotations __copyright__ = """ Copyright (C) 2020 Matt Wala +Copyright (C) 2020-21 Kaushik Kulkarni +Copyright (C) 2020-21 University of Illinois Board of Trustees """ __license__ = """ @@ -24,10 +26,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, - List, Mapping, Iterable, Optional, Tuple) -from pyrsistent.typing import PMap as PMapT + List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING) from pytato.array import ( Array, IndexLambda, Placeholder, Stack, Roll, @@ -36,10 +37,13 @@ from pytato.array import ( IndexRemappingBase, Einsum, InputArgumentBase, BasicIndex, AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, IndexBase) + from pytato.loopy import LoopyCall, LoopyCallResult from dataclasses import dataclass from pytato.tags import ImplStored -from pyrsistent import pmap + +if TYPE_CHECKING: + from pytato.distributed import DistributedSendRefHolder, DistributedRecv T = TypeVar("T", Array, AbstractResultWithNamedArrays) CombineT = TypeVar("CombineT") # used in CombineMapper @@ -74,6 +78,7 @@ Dict representation of DAGs .. autoclass:: UsersCollector .. autofunction:: reverse_graph +.. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes Internal stuff that is only here because the documentation tool wants it @@ -306,6 +311,23 @@ class CopyMapper(CachedMapper[ArrayOrNames]): axes=expr.axes, tags=expr.tags) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> Array: + from pytato.distributed import DistributedSend, DistributedSendRefHolder + return DistributedSendRefHolder( + DistributedSend( + data=self.rec(expr.send.data), + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag), + self.rec(expr.passthrough_data)) + + def map_distributed_recv(self, expr: DistributedRecv) -> Array: + from pytato.distributed import DistributedRecv + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=self.rec_idx_or_size_tuple(expr.shape), + dtype=expr.dtype, tags=expr.tags) + # }}} @@ -400,6 +422,16 @@ class CombineMapper(Mapper, Generic[CombineT]): def map_loopy_call_result(self, expr: LoopyCallResult) -> CombineT: return self.rec(expr._container) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> CombineT: + return self.combine( + self.rec(expr.send.data), + self.rec(expr.passthrough_data), + ) + + def map_distributed_recv(self, expr: DistributedRecv) -> CombineT: + return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + # }}} @@ -459,6 +491,14 @@ class DependencyMapper(CombineMapper[R]): def map_loopy_call_result(self, expr: LoopyCallResult) -> R: return self.combine(frozenset([expr]), super().map_loopy_call_result(expr)) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> R: + return self.combine( + frozenset([expr]), super().map_distributed_send_ref_holder(expr)) + + def map_distributed_recv(self, expr: DistributedRecv) -> R: + return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) + # }}} @@ -648,6 +688,24 @@ class WalkMapper(Mapper): self.post_visit(expr) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> None: + if not self.visit(expr): + return + + self.rec(expr.send.data) + self.rec(expr.passthrough_data) + + self.post_visit(expr) + + def map_distributed_recv(self, expr: DistributedRecv, *args: Any) -> None: + if not self.visit(expr): + return + + self.rec_idx_or_size_tuple(expr.shape) + + self.post_visit(expr) + def map_named_array(self, expr: NamedArray) -> None: if not self.visit(expr): return @@ -1031,7 +1089,9 @@ class 
UsersCollector(CachedMapper[ArrayOrNames]): def __init__(self) -> None: super().__init__() - self.node_to_users: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} + from pytato.distributed import DistributedSend + self.node_to_users: Dict[ArrayOrNames, + Set[Union[DistributedSend, ArrayOrNames]]] = {} def __call__(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: # Root node has no predecessor @@ -1129,16 +1189,25 @@ class UsersCollector(CachedMapper[ArrayOrNames]): self.node_to_users.setdefault(child, set()).add(expr) self.rec(child) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> None: + self.node_to_users.setdefault(expr.passthrough_data, set()).add(expr) + self.rec(expr.passthrough_data) + self.node_to_users.setdefault(expr.send.data, set()).add(expr.send) + self.rec(expr.send.data) + + def map_distributed_recv(self, expr: DistributedRecv, *args: Any) -> None: + self.rec_idx_or_size_tuple(expr, expr.shape) + -def get_users(expr: ArrayOrNames) -> PMapT[ArrayOrNames, - FrozenSet[ArrayOrNames]]: +def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, + Set[ArrayOrNames]]: """ Returns a mapping from node in *expr* to its direct users. """ user_collector = UsersCollector() user_collector(expr) - return pmap({node: frozenset(users) - for node, users in user_collector.node_to_users.items()}) + return user_collector.node_to_users # type: ignore # }}} @@ -1168,7 +1237,7 @@ def reverse_graph(graph: Dict[ArrayOrNames, Set[ArrayOrNames]]) \ def _recursively_get_all_users( - direct_users: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]], + direct_users: Mapping[ArrayOrNames, Set[ArrayOrNames]], node: ArrayOrNames) -> FrozenSet[ArrayOrNames]: result = set() queue = list(direct_users.get(node, set())) @@ -1201,8 +1270,8 @@ def rec_get_user_nodes(expr: ArrayOrNames, return _recursively_get_all_users(users, node) -def tag_child_nodes( - graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]], +def tag_user_nodes( + graph: Mapping[ArrayOrNames, Set[ArrayOrNames]], tag: Any, starting_point: ArrayOrNames, node_to_tags: Optional[Dict[ArrayOrNames, Set[ArrayOrNames]]] = None @@ -1214,12 +1283,11 @@ def tag_child_nodes( use case for this function is the graph in :attr:`UsersCollector.node_to_users`. :param tag: The value to tag the nodes with. - :param starting_point: An optional starting point in *graph*. + :param starting_point: A starting point in *graph*. :param node_to_tags: The resulting mapping of nodes to tags. - :returns: the updated value of *node_to_tags*. """ from warnings import warn - warn("tag_child_nodes is set for deprecation in June, 2022", + warn("tag_user_nodes is set for deprecation in June, 2022", DeprecationWarning) if node_to_tags is None: @@ -1237,7 +1305,7 @@ def tag_child_nodes( # {{{ EdgeCachedMapper -class EdgeCachedMapper(CachedMapper[ArrayOrNames], ABC): +class EdgeCachedMapper(CachedMapper[ArrayOrNames]): """ Mapper class to execute a rewriting method (:meth:`handle_edge`) on each edge in the graph. 
@@ -1384,6 +1452,24 @@ class EdgeCachedMapper(CachedMapper[ArrayOrNames], ABC):
                 for name, child in expr.bindings.items()},
             )
 
+    def map_distributed_send_ref_holder(
+            self, expr: DistributedSendRefHolder, *args: Any) -> \
+                DistributedSendRefHolder:
+        from pytato.distributed import DistributedSendRefHolder
+        return DistributedSendRefHolder(
+            # Keep the send node's metadata; only its data is rewritten.
+            send=expr.send.copy(data=self.handle_edge(expr, expr.send.data)),
+            passthrough_data=self.handle_edge(expr, expr.passthrough_data),
+            tags=expr.tags
+            )
+
+    def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \
+            -> Any:
+        from pytato.distributed import DistributedRecv
+        return DistributedRecv(
+               src_rank=expr.src_rank, comm_tag=expr.comm_tag,
+               shape=self.rec_idx_or_size_tuple(expr, expr.shape, *args),
+               dtype=expr.dtype, tags=expr.tags, axes=expr.axes)
+
 # }}}
 
 # }}}
diff --git a/pytato/visualization.py b/pytato/visualization.py
index cb7c48c995f6666114a29162f6d531202058947c..d1395860803e145222164cfb60d197192dbeb135 100644
--- a/pytato/visualization.py
+++ b/pytato/visualization.py
@@ -29,7 +29,9 @@ THE SOFTWARE.
 import contextlib
 import dataclasses
 import html
-from typing import Callable, Dict, Union, Iterator, List, Mapping, Hashable
+
+from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List,
+                    Mapping, Hashable)
 
 from pytools import UniqueNameGenerator
 from pytools.codegen import CodeGenerator as CodeGeneratorBase
@@ -44,6 +46,10 @@ from pytato.codegen import normalize_outputs
 from pytato.transform import CachedMapper, ArrayOrNames
 from pytato.partition import GraphPartition
 
+from pytato.distributed import DistributedGraphPart
+
+if TYPE_CHECKING:
+    from pytato.distributed import DistributedSendRefHolder
 
 __doc__ = """
@@ -80,7 +86,7 @@ def stringify_shape(shape: ShapeType) -> str:
     return "(" + ", ".join(components) + ")"
 
 
-class ArrayToDotNodeInfoMapper(CachedMapper[Array]):
+class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
     def __init__(self) -> None:
         super().__init__()
         self.nodes: Dict[ArrayOrNames, DotNodeInfo] = {}
@@ -111,8 +117,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
                 info.edges[field] = attr
 
             elif isinstance(attr, AbstractResultWithNamedArrays):
-                # type-ignore-reason: incompatible with superclass
-                self.rec(attr)  # type: ignore[arg-type]
+                self.rec(attr)
                 info.edges[field] = attr
 
             elif isinstance(attr, tuple):
@@ -189,6 +194,19 @@ class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
                 entrypoint=expr.entrypoint),
             edges=edges)
 
+    def map_distributed_send_ref_holder(
+            self, expr: DistributedSendRefHolder) -> None:
+
+        info = self.get_common_dot_info(expr)
+
+        self.rec(expr.passthrough_data)
+        info.edges["passthrough"] = expr.passthrough_data
+
+        self.rec(expr.send.data)
+        info.edges["sent"] = expr.send.data
+
+        self.nodes[expr] = info
+
 
 def dot_escape(s: str) -> str:
     # "\" and HTML are significant in graphviz.
@@ -205,22 +223,22 @@ class DotEmitter(CodeGeneratorBase):
         self("}")
 
 
-def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str,
-        color: str = "white") -> None:
+def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
+        dot_node_id: str, color: str = "white") -> None:
     td_attrib = 'border="0"'
     table_attrib = 'border="0" cellborder="1" cellspacing="0"'
 
     rows = ['<tr><td colspan="2" %s>%s</td></tr>'
-            % (td_attrib, dot_escape(info.title))]
+            % (td_attrib, dot_escape(title))]
 
-    for name, field in info.fields.items():
+    for name, field in fields.items():
         field_content = dot_escape(field).replace("\n", "<br/>")
") rows.append( f"{dot_escape(name)}:" f"{field_content}" ) table = "\n%s
" % (table_attrib, "".join(rows)) - emit("%s [label=<%s> style=filled fillcolor=%s]" % (id, table, color)) + emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color)) def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames], @@ -282,11 +300,13 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: with emit.block("subgraph cluster_Inputs"): emit('label="Inputs"') for array in input_arrays: - _emit_array(emit, nodes[array], array_to_id[array]) + _emit_array(emit, + nodes[array].title, nodes[array].fields, array_to_id[array]) # Emit non-inputs. for array in internal_arrays: - _emit_array(emit, nodes[array], array_to_id[array]) + _emit_array(emit, + nodes[array].title, nodes[array].fields, array_to_id[array]) # Emit edges. for array, node in nodes.items(): @@ -334,6 +354,26 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: # Second pass: emit the graph. for part in partition.parts.values(): + # {{{ emit receives nodes if distributed + + if isinstance(part, DistributedGraphPart): + part_dist_recv_var_name_to_node_id = {} + for name, recv in ( + part.input_name_to_recv_node.items()): + node_id = id_gen("recv") + _emit_array(emit, "Recv", { + "shape": stringify_shape(recv.shape), + "dtype": str(recv.dtype), + "src_rank": str(recv.src_rank), + "comm_tag": str(recv.comm_tag), + }, node_id) + + part_dist_recv_var_name_to_node_id[name] = node_id + else: + part_dist_recv_var_name_to_node_id = {} + + # }}} + part_node_to_info = part_id_to_node_info[part.pid] input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] @@ -355,13 +395,27 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: # Non-Placeholders are emitted *inside* their subgraphs below. if isinstance(array, Placeholder): if array not in emitted_placeholders: - _emit_array(emit, part_node_to_info[array], + _emit_array(emit, + part_node_to_info[array].title, + part_node_to_info[array].fields, array_to_id[array], "deepskyblue") - emitted_placeholders.add(array) # Emit cross-partition edges - tgt = array_to_id[partition.var_name_to_result[array.name]] - emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") + if array.name in part_dist_recv_var_name_to_node_id: + tgt = part_dist_recv_var_name_to_node_id[array.name] + emit(f"{tgt} -> {array_to_id[array]} [style=dotted]") + emitted_placeholders.add(array) + elif array.name in part.user_input_names: + # These are placeholders for external input. They + # are cleanly associated with a single partition + # and thus emitted below. 
+ pass + else: + # placeholder for a value from a different partition + tgt = array_to_id[ + partition.var_name_to_result[array.name]] + emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") + emitted_placeholders.add(array) # }}} @@ -370,13 +424,42 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: emit(f'label="{part.pid}"') for array in input_arrays: - if not isinstance(array, Placeholder): - _emit_array(emit, part_node_to_info[array], + if (not isinstance(array, Placeholder) + or array.name in part.user_input_names): + _emit_array(emit, + part_node_to_info[array].title, + part_node_to_info[array].fields, array_to_id[array], "deepskyblue") # Emit internal nodes for array in internal_arrays: - _emit_array(emit, part_node_to_info[array], array_to_id[array]) + _emit_array(emit, + part_node_to_info[array].title, + part_node_to_info[array].fields, + array_to_id[array]) + + # {{{ emit send nodes if distributed + + deferred_send_edges = [] + if isinstance(part, DistributedGraphPart): + for name, send in ( + part.output_name_to_send_node.items()): + node_id = id_gen("send") + _emit_array(emit, "Send", { + "dest_rank": str(send.dest_rank), + "comm_tag": str(send.comm_tag), + }, node_id) + + deferred_send_edges.append( + f"{array_to_id[send.data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') + + # }}} + + # If an edge is emitted in a subgraph, it drags its nodes into the + # subgraph, too. Not what we want. + for edge in deferred_send_edges: + emit(edge) # Emit intra-partition edges for array, node in part_node_to_info.items(): diff --git a/test/test_distributed.py b/test/test_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac44ad4c158cf64d761e6dd7dcb5fc1027c3920 --- /dev/null +++ b/test/test_distributed.py @@ -0,0 +1,221 @@ +__copyright__ = """Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import pyopencl as cl +import numpy as np +import pytato as pt +import sys +import os + +from pytato.distributed import (staple_distributed_send, make_distributed_recv, + find_distributed_partition, + execute_distributed_partition, number_distributed_tags) + +from pytato.partition import generate_code_for_partition + + +# {{{ mpi test infrastructure + +def run_test_with_mpi(num_ranks, f, *args): + import pytest + pytest.importorskip("mpi4py") + + from pickle import dumps + from base64 import b64encode + + invocation_info = b64encode(dumps((f, args))).decode() + from subprocess import check_call + + # NOTE: CI uses OpenMPI; -x to pass env vars. MPICH uses -env + check_call([ + "mpiexec", "-np", str(num_ranks), + "-x", "RUN_WITHIN_MPI=1", + "-x", f"INVOCATION_INFO={invocation_info}", + sys.executable, __file__]) + + +def run_test_with_mpi_inner(): + from pickle import loads + from base64 import b64decode + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(cl.create_some_context, *args) + +# }}} + + +# {{{ "basic" test (similar to distributed example) + +def test_distributed_execution_basic(): + run_test_with_mpi(2, _do_test_distributed_execution_basic) + + +def _do_test_distributed_execution_basic(ctx_factory): + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + + rank = comm.Get_rank() + size = comm.Get_size() + + rng = np.random.default_rng(seed=27) + + x_in = rng.integers(100, size=(4, 4)) + x = pt.make_data_wrapper(x_in) + + halo = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=42, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int)) + + y = x+halo + + # Find the partition + outputs = pt.DictOfNamedArrays({"out": y}) + distributed_parts = find_distributed_partition(outputs) + prg_per_partition = generate_code_for_partition(distributed_parts) + + # Execute the distributed partition + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + context = execute_distributed_partition(distributed_parts, prg_per_partition, + queue, comm) + + final_res = context["out"].get(queue) + + # All ranks generate the same random numbers (same seed). 
+ np.testing.assert_allclose(x_in*2, final_res) + +# }}} + + +# {{{ test based on random dag + +def test_distributed_execution_random_dag(): + run_test_with_mpi(2, _do_test_distributed_execution_random_dag) + + +class _RandomDAGTag: + pass + + +def _do_test_distributed_execution_random_dag(ctx_factory): + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + rank = comm.Get_rank() + size = comm.Get_size() + + from testlib import RandomDAGContext, make_random_dag + + axis_len = 4 + comm_fake_prob = 500 + + gen_comm_called = False + + ntests = 10 + for i in range(ntests): + seed = 120 + i + print(f"Step {i} {seed}") + + # {{{ compute value with communication + + comm_tag = 17 + + def gen_comm(rdagc): + nonlocal gen_comm_called + gen_comm_called = True + + nonlocal comm_tag + comm_tag += 1 + tag = (comm_tag, _RandomDAGTag) + + inner = make_random_dag(rdagc) + return staple_distributed_send( + inner, dest_rank=(rank-1) % size, comm_tag=tag, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=tag, + shape=inner.shape, dtype=inner.dtype)) + + rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False, + additional_generators=[ + (comm_fake_prob, gen_comm) + ]) + x_comm = make_random_dag(rdagc_comm) + + distributed_partition = find_distributed_partition( + pt.DictOfNamedArrays({"result": x_comm})) + + # Transform symbolic tags into numeric ones for MPI + distributed_partition, _new_mpi_base_tag = number_distributed_tags( + comm, + distributed_partition, + base_tag=comm_tag) + + prg_per_partition = generate_code_for_partition(distributed_partition) + + context = execute_distributed_partition( + distributed_partition, prg_per_partition, queue, comm) + + res_comm = context["result"] + + # }}} + + # {{{ compute ref value without communication + + # compiled evaluation (i.e. use_numpy=False) fails for some of these + # graphs, cf. https://github.com/inducer/pytato/pull/255 + rdagc_no_comm = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=True, + additional_generators=[ + (comm_fake_prob, lambda rdagc: make_random_dag(rdagc)) + ]) + res_no_comm_numpy = make_random_dag(rdagc_no_comm) + + # }}} + + if not isinstance(res_comm, np.ndarray): + res_comm = res_comm.get(queue=queue) + + np.testing.assert_allclose(res_comm, res_no_comm_numpy) + + assert gen_comm_called + +# }}} + + +if __name__ == "__main__": + if "RUN_WITHIN_MPI" in os.environ: + run_test_with_mpi_inner() + elif len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: foldmethod=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 43b8188f9f40cb9443ed9bf3d47a6cb9fd0871db..f85bff9d5f1ea4ff4c716a0fd37fd7f5ca6b5c53 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -638,7 +638,7 @@ def test_rec_get_user_nodes_linear_complexity(): assert (expected_result == result) -def test_tag_child_nodes_linear_complexity(): +def test_tag_user_nodes_linear_complexity(): from numpy.random import default_rng def construct_intestine_graph(depth=100, seed=0): @@ -663,7 +663,7 @@ def test_tag_child_nodes_linear_complexity(): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() - result = pt.transform.tag_child_nodes(user_collector.node_to_users, "foo", inp) + result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) assert expected_result == result
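Editorial usage note (not part of the diff itself): assuming a working MPI
installation with mpi4py, the new functionality above can be exercised
locally with the same commands CI uses:

    mpirun -n 2 python examples/distributed.py
    python -m pytest test/test_distributed.py

The pytest entry point re-launches itself under mpiexec via
run_test_with_mpi, as implemented in test/test_distributed.py.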