From 2cfeb0eeb2464ae9f17e15af5a22d99c6a5073ac Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 8 Feb 2022 22:55:44 -0600 Subject: [PATCH] Distributed v3 (#148) * add test * use are_shape_components_equal Co-authored-by: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> * fix * Add random graph generation * push current state * Random DAG generator actually tests things * Fix show_dot_graph invocation that tripped up pylint * Partitioner: Partitions exist even when they have no in/out edges * Partitioner: Refactor disjointness check into separate function * get_dot_graph_from_partitions: Do not require toposorted partition list * test_partitioner: properly use make_random_dag * Move transform functionality to separate file, fix partition type annotations * Use PartitionId type alias for partition IDs * Add pytato.partition to docs * Fix flake8 in partition * Code folding in pytato.partition * Partition: rename _handle_new_binding -> _handle_parent_child * Split EdgeCachedMapper out of graph partitioner * EdgeCachedMapper: support *args * Remove beginnings of _PartitionSplitter * Add PartitionInducedCycleError, bail on test if cycle encountered * remove DictOfNamedArrays.__str__ * add GraphToDictMapper * remove *args, derive from CachedMapper * rename to Userscollector, add index map functions * flake8 * add random DAG test * better doc * flake8 * remove changes to get_dot_graph * remove spurious array.py changes * add missing distributed.py * misc fixes * doc fix * fix equality * Update array.py * Update partition.py * Update __init__.py * misc type fixes * rename * mypy * make example run again * fix and rename tag_child_nodes * add test * flake8 * use ints * dont duplicate input arrays * make ArrayToDotNodeInfoMapper a cached mapper * make ArrayToDotNodeInfoMapper a cached mapper * fixes * fix get_dot_graph_from_partitions * various constructor fixes * inputs fix * MPI fixes * comm passing * DistributedSend: remove shape,dtype, Recv: remove data * Partitioned vis: only emit Placeholders once * Hook distributed docs into main docs * Refactor distributed for DistributedSendRefHolder * misc fixes * add a few sanity checks * fix reverse_graph * flake8 * fix recv hang Previously, send was stapled to recv, and both were in the same partition. This meant that first the irecv was waited on before the corresponing send, leading to a deadlock. * Revert "fix recv hang" This reverts commit ba374bc91c60cb33374e4aa7034024a8f18818b1. * Rename DistributedSendRefHolder.{data->passthrough_data} * UsersCollector: Collect DistributedSend as a separate user * Add ArrayToDotNodeInfoMapper.map_distributed_send_ref_holder to appropriately traverse send * Add nitpicky style FIXMEs to CodePartitions * Drop _DistributedCommReplacer from generate_code_for_partitions * Make DistributedSend a Hashable * make_distributed_recv: normalize dtype * Create DistributedGraphPartitions, rework representation of distributed graphs * Teach the graph visualizer how to show DistributedGraphPartitions * Adapt distributed example to new distributed graph data structure * rename CodePartitions to GraphPartitions * fixup CodePartitions rename * mypy fixes * bail before execution * make recv part of partition * make addr part of fields * doc fix * Revert "make recv part of partition" This reverts commit 8de5a686b43159a3fd19730cf09c79002c7f4e22. * Refactor partitioning to use fewer dicts, preserve partial part order * Track fewer-dicts refactor of partitioner in distributed * Dynamically select ready parts in execute_partition_distributed * Rename _GraphPartitioner.{seen_partition_ids->seen_part_ids} * Partitioner: more partition -> part renaming * Mypy-clean distributed execution * Partitioned Vis: Emit non-Placeholder input arrays inside their partitions * Make map_distributed_send_ref_holder an abstract method of EdgeCachedMapper * find_partition: Use better var name for _GraphPartitioner * Teach find_partition to handle DistributedSend (fixable design fail) * gather_distributed_comm_info: Handle sends extracted by partitioner * Distributed example: make get_part_id a nested function * Distributed exec: use non-blocking send, wait for send request completion * Fix direction of Part.needed_pids * Distribted exec: ready_pids: actually contain pids * Distributed exec: fix minor bugs, distributed example works * Remove abort from distributed example app * Placate flake8 about distributed example * Refactor partitioning to use fewer dicts, preserve partial part order * Rename _GraphPartitioner.{seen_partition_ids->seen_part_ids} * Partitioner: more partition -> part renaming * Partitioned Vis: Emit non-Placeholder input arrays inside their partitions * Test get_dot_graph_from_partition as part of test_partitioner * find_partition: Use better var name for _GraphPartitioner * Fix direction of Part.needed_pids * Fix find_partition * Fix doc warnings * Fix find_partition * Fix doc warnings * lint fixes * run CI with multiple ranks * mpi fix * another ci fix * ci fix * work around doc build failure * better test for example + cleanup * extract find_partition_distributed * add basic pytest * add random dag test * change comm tag * add first pass comment * fix get_dot_graph_from_partition doc * export staple_distributed_send * expoet more functions * add missing map_loopy_call * Change canonical import location for distributed functionality * Re-break circular imports for doc build in pytato.partition * add comment to doc * Use MRO to find Array mapper methods * Use a separate mapper method for LoopyCallResult * Visualization: Skip data in DataWrapper * Visualization: Support visualization of LoopyCall, DictOfNamedArray * Partition: gather user_input_names for each part * Partitioned visualization: Do not mishandle Placeholders from user input * lint fixes * simplify pid_to_user_input_names * spelling * Fix partition/vis type annotations * Ensure ph names are unique across partitions in _DistributedCommReplacer * Dist: Rename *_partition_distributed -> *_distributed_partition, drop gather_comm_info * Clarify, rename GraphPart.{user,partition}_input_names * send numpy data * receive cl buffer * small name change * add fixme * Distributed receive: Use cl.array.to_device * Fix DistributedSend.copy: args are optional * Fix DistributedRecv._fields: shape and dtype were missing * Distributed: rename arguments comm -> mpi_communicator * Distributed: Support symbolic tags, add number_distributed_tags * fix running with a single rank * opencl/numpy arg fixes * lint fixes * another doc fix * frozenset in number_tags * Refactor _GraphPartitioner/find_partition so that it does not know about distributed_sends * attempt to address axes changes * better checking for non-existing recvs * rename _gather_distributed_comm_info arg * simplify getting output * mypy fix * show hex id * fix walkmapper for dist recv * better LoopyCall visualization * debug * add missing axes * assert that we are giving find_distributed_partition not too many outputs * fix merge error * disable some debug * flake * less strict disjoint checking * lint fixes * fixes * another one * add to stringifier * support multiple outputs in find_distributed_partition * reprifier * undo disable disjoint check * undo check disable * flake8 * better test * Avoid type-ignores in single-rank case of execute_distributed_partition * find_partition: avoid passing a partitioner instance * Revert some tag/einsum merge accidents * simplify recv_names* * Random DAG generator: Add support for user-supplied 'additional_generators' * Distributed tests: actually run with MPI, fix 'basic' test, use comm nodes in random DAG * use number_distributed_tags * document partitioner class * add comment regarding renumbering * refactor random comm generation * restore currentmodule:: pytato.tags (was this removed accidentally?) * fix doc build * fix flake8 * remove TODO (attributes are kind of obvious now) * cleanup _check_partition_disjointness * cleanup distributed.py imports * fix doc build * add NodeCountMapper * fix doc * Document GraphPartitioner interface * Revert "refactor random comm generation" This reverts commit 4fd2dd1873302a648a8168a745cfcd0630a319b5. * Add a comment explaining index reversal in distributed exec * _do_test_distributed_execution_random_dag: Use numpy for reference result Compiled evaluation for these graphs seems to compute incorrect results, see gh-255. Co-authored-by: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> Co-authored-by: [6~ --- .github/workflows/ci.yml | 2 + .test-conda-env-py3.yml | 1 + doc/dag.rst | 5 + doc/design.rst | 3 + examples/distributed.py | 64 ++++ pytato/__init__.py | 19 +- pytato/distributed.py | 675 +++++++++++++++++++++++++++++++++++++++ pytato/equality.py | 21 ++ pytato/partition.py | 188 +++++++---- pytato/stringifier.py | 2 + pytato/transform.py | 118 ++++++- pytato/visualization.py | 119 +++++-- test/test_distributed.py | 221 +++++++++++++ test/test_pytato.py | 4 +- 14 files changed, 1339 insertions(+), 103 deletions(-) create mode 100644 examples/distributed.py create mode 100644 pytato/distributed.py create mode 100644 test/test_distributed.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1ae714..f098a94 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -79,6 +79,8 @@ jobs: build_py_project_in_conda_env run_examples + mpirun -n 2 python distributed.py + docs: name: Documentation runs-on: ubuntu-latest diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 7c34e94..6459d4c 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -11,3 +11,4 @@ dependencies: - pyopencl - islpy - sphinx-autodoc-typehints +- mpi4py diff --git a/doc/dag.rst b/doc/dag.rst index b6e7e7d..f4874f1 100644 --- a/doc/dag.rst +++ b/doc/dag.rst @@ -28,6 +28,11 @@ Partitioning Array Expression Graphs .. automodule:: pytato.partition +Support for Distributed-Memory/Message Passing +============================================== + +.. automodule:: pytato.distributed + Utilities and Diagnostics ========================= diff --git a/doc/design.rst b/doc/design.rst index 71984d4..f402a93 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -146,6 +146,9 @@ Reserved Identifiers transport across parts of a partitioned DAG (cf. :func:`~pytato.partition.find_partition`). + - ``_pt_dist``: Used to automatically generate identifiers for + placeholders at distributed communication boundaries. + - Identifiers used in index lambdas are also reserved. These include: - Identifiers matching the regular expression ``_[0-9]+``. They are used diff --git a/examples/distributed.py b/examples/distributed.py new file mode 100644 index 0000000..1f12839 --- /dev/null +++ b/examples/distributed.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python + +from mpi4py import MPI # pylint: disable=import-error +comm = MPI.COMM_WORLD + +import pytato as pt +import pyopencl as cl +import numpy as np + +from pytato import (find_distributed_partition, generate_code_for_partition, + number_distributed_tags, + execute_distributed_partition, + staple_distributed_send, make_distributed_recv) + + +def main(): + rank = comm.Get_rank() + size = comm.Get_size() + rng = np.random.default_rng() + + x_in = rng.integers(100, size=(4, 4)) + x = pt.make_data_wrapper(x_in) + + mytag = (main, "x") + halo = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=mytag, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=mytag, shape=(4, 4), dtype=int)) + + y = x+halo + + # Find the partition + outputs = pt.DictOfNamedArrays({"out": y}) + distributed_parts = find_distributed_partition(outputs) + distributed_parts, _ = number_distributed_tags( + comm, distributed_parts, base_tag=42) + prg_per_partition = generate_code_for_partition(distributed_parts) + + if 0: + from pytato.visualization import show_dot_graph + show_dot_graph(distributed_parts) + + # Sanity check + from pytato.visualization import get_dot_graph_from_partition + get_dot_graph_from_partition(distributed_parts) + + # Execute the distributed partition + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + context = execute_distributed_partition(distributed_parts, prg_per_partition, + queue, comm) + + final_res = context["out"].get(queue) + + ref_res = comm.bcast(final_res) + + np.testing.assert_allclose(ref_res, final_res) + + if rank == 0: + print("Distributed test succeeded.") + + +if __name__ == "__main__": + main() diff --git a/pytato/__init__.py b/pytato/__init__.py index 97444b6..6582dbe 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -32,7 +32,6 @@ from pytato.array import ( AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, SizeParam, Axis, - make_dict_of_named_arrays, make_placeholder, make_size_param, make_data_wrapper, einsum, @@ -70,6 +69,15 @@ from pytato.visualization import (get_dot_graph, show_dot_graph, import pytato.analysis as analysis import pytato.tags as tags import pytato.transform as transform +from pytato.distributed import (make_distributed_send, make_distributed_recv, + DistributedRecv, DistributedSend, + DistributedSendRefHolder, + staple_distributed_send, + find_distributed_partition, + number_distributed_tags, + execute_distributed_partition) + +from pytato.partition import generate_code_for_partition __all__ = ( "Array", "AbstractResultWithNamedArrays", "DictOfNamedArrays", @@ -111,6 +119,15 @@ __all__ = ( "broadcast_to", + "make_distributed_recv", "make_distributed_send", "DistributedRecv", + "DistributedSend", "staple_distributed_send", "DistributedSendRefHolder", + + "find_distributed_partition", + "number_distributed_tags", + "execute_distributed_partition", + + "generate_code_for_partition", + # sub-modules "analysis", "tags", "transform", diff --git a/pytato/distributed.py b/pytato/distributed.py new file mode 100644 index 0000000..6624d6b --- /dev/null +++ b/pytato/distributed.py @@ -0,0 +1,675 @@ +from __future__ import annotations + +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import (Any, Dict, Hashable, Tuple, Optional, Set, # noqa: F401 + List, FrozenSet, Callable, cast, Mapping) # Mapping required by sphinx + +from dataclasses import dataclass + +from pytools import UniqueNameGenerator +from pytools.tag import Taggable, TagsType +from pytato.array import (Array, _SuppliedShapeAndDtypeMixin, + DictOfNamedArrays, ShapeType, + Placeholder, make_placeholder, + _get_default_axes, AxesT) +from pytato.transform import ArrayOrNames, CopyMapper +from pytato.partition import GraphPart, GraphPartition, PartId, GraphPartitioner +from pytato.target import BoundProgram + +import numpy as np + +__doc__ = r""" +Distributed communication +------------------------- + +.. currentmodule:: pytato +.. autoclass:: DistributedSend +.. autoclass:: DistributedSendRefHolder +.. autoclass:: DistributedRecv + +.. autofunction:: make_distributed_send +.. autofunction:: staple_distributed_send +.. autofunction:: make_distributed_recv + +.. currentmodule:: pytato.distributed + +.. autoclass:: DistributedGraphPart +.. autoclass:: DistributedGraphPartition + +.. currentmodule:: pytato + +.. autofunction:: find_distributed_partition +.. autofunction:: number_distributed_tags +.. autofunction:: execute_distributed_partition + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: CommTagType + + A type representing a communication tag. + +.. class:: TagsType + + A :class:`frozenset` of :class:`pytools.tag.Tag`\ s. + +.. class:: ShapeType + + A type representing a shape. + +.. class:: AxesT + + A :class:`tuple` of :class:`Axis` objects. +""" + + +# {{{ Distributed node types + +CommTagType = Hashable + + +class DistributedSend(Taggable): + """Class representing a distributed send operation. + + .. attribute:: data + + The :class:`~pytato.Array` to be sent. + + .. attribute:: dest_rank + + An :class:`int`. The rank to which :attr:`data` is to be sent. + + .. attribute:: comm_tag + + A hashable, picklable object to serve as a 'tag' for the communication. + Only a :class:`DistributedRecv` with the same tag will be able to + receive the data being sent here. + """ + + def __init__(self, data: Array, dest_rank: int, comm_tag: CommTagType, + tags: TagsType = frozenset()) -> None: + super().__init__(tags=tags) + self.data = data + self.dest_rank = dest_rank + self.comm_tag = comm_tag + + def __hash__(self) -> int: + return ( + hash(self.__class__) + ^ hash(self.data) + ^ hash(self.dest_rank) + ^ hash(self.comm_tag) + ^ hash(self.tags) + ) + + def __eq__(self, other: Any) -> bool: + return ( + self.__class__ is other.__class__ + and self.data == other.data + and self.dest_rank == other.dest_rank + and self.comm_tag == other.comm_tag + and self.tags == other.tags) + + def copy(self, **kwargs: Any) -> DistributedSend: + data: Optional[Array] = kwargs.get("data") + dest_rank: Optional[int] = kwargs.get("dest_rank") + comm_tag: Optional[CommTagType] = kwargs.get("comm_tag") + tags: Optional[TagsType] = kwargs.get("tags") + return type(self)( + data=data or self.data, + dest_rank=dest_rank if dest_rank is not None else self.dest_rank, + comm_tag=comm_tag if comm_tag is not None else self.comm_tag, + tags=tags if tags is not None else self.tags) + + +class DistributedSendRefHolder(Array): + """A node acting as an identity on :attr:`passthrough_data` while also holding + a reference to a :class:`DistributedSend` in :attr:`send`. Since + :mod:`pytato` represents data flow, and since no data flows 'out' + of a :class:`DistributedSend`, no node in all of :mod:`pytato` has + a good reason to hold a reference to a send node, since there is + no useful result of a send (at least of of an :class:`~pytato.Array` type). + + This is where this node type comes in. Its value is the same as that of + :attr:`passthrough_data`, *and* it holds a reference to the send node. + + .. note:: + + This all seems a wee bit inelegant, but nobody who has written + or reviewed this code so far had a better idea. If you do, please speak up! + + .. attribute:: send + + The :class:`DistributedSend` to which a reference is to be held. + + .. attribute:: passthrough_data + + A :class:`~pytato.Array`. The value of this node. + """ + + _mapper_method = "map_distributed_send_ref_holder" + _fields = Array._fields + ("passthrough_data", "send",) + + def __init__(self, send: DistributedSend, passthrough_data: Array, + tags: TagsType = frozenset()) -> None: + super().__init__(axes=passthrough_data.axes, tags=tags) + self.send = send + self.passthrough_data = passthrough_data + + @property + def shape(self) -> ShapeType: + return self.passthrough_data.shape + + @property + def dtype(self) -> np.dtype[Any]: + return self.passthrough_data.dtype + + +class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array): + """Class representing a distributed receive operation. + + .. attribute:: src_rank + + An :class:`int`. The rank from which an array is to be received. + + .. attribute:: comm_tag + + A hashable, picklable object to serve as a 'tag' for the communication. + Only a :class:`DistributedRecv` with the same tag will be able to + receive the data being sent here. + + .. attribute:: shape + .. attribute:: dtype + """ + + _fields = Array._fields + ("shape", "dtype", "src_rank", "comm_tag") + _mapper_method = "map_distributed_recv" + + def __init__(self, src_rank: int, comm_tag: CommTagType, + shape: ShapeType, dtype: Any, + tags: Optional[TagsType] = frozenset(), + axes: Optional[AxesT] = None) -> None: + + if not axes: + axes = _get_default_axes(len(shape)) + super().__init__(shape=shape, dtype=dtype, tags=tags, + axes=axes) + self.src_rank = src_rank + self.comm_tag = comm_tag + + +def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, + send_tags: TagsType = frozenset()) -> \ + DistributedSend: + return DistributedSend(sent_data, dest_rank, comm_tag, send_tags) + + +def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, + stapled_to: Array, *, + send_tags: TagsType = frozenset(), + ref_holder_tags: TagsType = frozenset()) -> \ + DistributedSendRefHolder: + return DistributedSendRefHolder( + DistributedSend(sent_data, dest_rank, comm_tag, send_tags), + stapled_to, tags=ref_holder_tags) + + +def make_distributed_recv(src_rank: int, comm_tag: CommTagType, + shape: ShapeType, dtype: Any, + tags: TagsType = frozenset()) \ + -> DistributedRecv: + dtype = np.dtype(dtype) + return DistributedRecv(src_rank, comm_tag, shape, dtype, tags) + +# }}} + + +# {{{ distributed info collection + +@dataclass(frozen=True) +class DistributedGraphPart(GraphPart): + """For one graph partition, record send/receive information for input/ + output names. + + .. attribute:: input_name_to_recv_node + .. attribute:: output_name_to_send_node + .. attribute:: distributed_sends + """ + input_name_to_recv_node: Dict[str, DistributedRecv] + output_name_to_send_node: Dict[str, DistributedSend] + distributed_sends: List[DistributedSend] + + +@dataclass(frozen=True) +class DistributedGraphPartition(GraphPartition): + """Store information about distributed graph partitions. This + has the same attributes as :class:`~pytato.partition.GraphPartition`, + however :attr:`~pytato.partition.GraphPartition.parts` now maps to + instances of :class:`DistributedGraphPart`. + """ + parts: Dict[PartId, DistributedGraphPart] + + +class _DistributedCommReplacer(CopyMapper): + """Mapper to process a DAG for realization of :class:`DistributedSend` + and :class:`DistributedRecv` outside of normal code generation. + + - Replaces :class:`DistributedRecv` with :class`~pytato.Placeholder` + so that received data can be externally supplied, making a note + in :attr:`input_name_to_recv_node`. + + - Eliminates :class:`DistributedSendRefHolder` and + :class:`DistributedSend` from the DAG, making a note of data + to be send in :attr:`output_name_to_send_node`. + """ + + def __init__(self, dist_name_generator: UniqueNameGenerator) -> None: + super().__init__() + + self.name_generator = dist_name_generator + + self.input_name_to_recv_node: Dict[str, DistributedRecv] = {} + self.output_name_to_send_node: Dict[str, DistributedSend] = {} + + def map_distributed_recv(self, expr: DistributedRecv) -> Placeholder: + # no children, no need to recurse + + new_name = self.name_generator() + self.input_name_to_recv_node[new_name] = expr + return make_placeholder(new_name, self.rec_idx_or_size_tuple(expr.shape), + expr.dtype, tags=expr.tags) + + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> Array: + raise ValueError("DistributedSendRefHolder should not occur in partitioned " + "graphs") + + def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: + new_send = DistributedSend( + data=self.rec(expr.data), + dest_rank=expr.dest_rank, + comm_tag=expr.comm_tag, + tags=expr.tags) + + new_name = self.name_generator() + self.output_name_to_send_node[new_name] = new_send + + return new_send + + +def _gather_distributed_comm_info(partition: GraphPartition, + pid_to_distributed_sends: Dict[PartId, List[DistributedSend]]) -> \ + DistributedGraphPartition: + var_name_to_result = {} + parts: Dict[PartId, DistributedGraphPart] = {} + + dist_name_generator = UniqueNameGenerator(forced_prefix="_pt_dist_") + + for part in partition.parts.values(): + comm_replacer = _DistributedCommReplacer(dist_name_generator) + part_results = { + var_name: comm_replacer(partition.var_name_to_result[var_name]) + for var_name in part.output_names} + + dist_sends = [ + comm_replacer.map_distributed_send(send) + for send in pid_to_distributed_sends.get(part.pid, [])] + + part_results.update({ + name: send_node.data + for name, send_node in + comm_replacer.output_name_to_send_node.items()}) + + parts[part.pid] = DistributedGraphPart( + pid=part.pid, + needed_pids=part.needed_pids, + user_input_names=part.user_input_names, + partition_input_names=(part.partition_input_names + | frozenset(comm_replacer.input_name_to_recv_node)), + output_names=(part.output_names + | frozenset(comm_replacer.output_name_to_send_node)), + distributed_sends=dist_sends, + + input_name_to_recv_node=comm_replacer.input_name_to_recv_node, + output_name_to_send_node=comm_replacer.output_name_to_send_node) + + for name, val in part_results.items(): + assert name not in var_name_to_result + var_name_to_result[name] = val + + result = DistributedGraphPartition( + parts=parts, + var_name_to_result=var_name_to_result, + toposorted_part_ids=partition.toposorted_part_ids) + + if __debug__: + # Check disjointness again since we replaced a few nodes. + from pytato.partition import _check_partition_disjointness + _check_partition_disjointness(result) + + return result + +# }}} + + +# {{{ find distributed partition + +@dataclass(frozen=True, eq=True) +class DistributedPartitionId(): + fed_sends: object + feeding_recvs: object + + +class _DistributedGraphPartitioner(GraphPartitioner): + + def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: + super().__init__(get_part_id) + self.pid_to_dist_sends: Dict[PartId, List[DistributedSend]] = {} + + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> Any: + send_part_id = self.get_part_id(expr.send.data) + + from pytato.distributed import DistributedSend + self.pid_to_dist_sends.setdefault(send_part_id, []).append( + DistributedSend( + data=self.rec(expr.send.data), + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags)) + + return self.rec(expr.passthrough_data) + + def make_partition(self, outputs: DictOfNamedArrays) \ + -> DistributedGraphPartition: + + partition = super().make_partition(outputs) + return _gather_distributed_comm_info(partition, self.pid_to_dist_sends) + + +def find_distributed_partition( + outputs: DictOfNamedArrays) -> DistributedGraphPartition: + """Finds a partitioning in a distributed context.""" + + from pytato.transform import (UsersCollector, TopoSortMapper, + reverse_graph, tag_user_nodes) + + gdm = UsersCollector() + gdm(outputs) + + graph = gdm.node_to_users + + # type-ignore-reason: + # 'graph' also includes DistributedSend nodes, which are not Arrays + rev_graph = reverse_graph(graph) # type: ignore[arg-type] + + # FIXME: Inefficient... too many traversals + node_to_feeding_recvs: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} + for node in graph: + node_to_feeding_recvs.setdefault(node, set()) + if isinstance(node, DistributedRecv): + tag_user_nodes(graph, tag=node, # type: ignore[arg-type] + starting_point=node, + node_to_tags=node_to_feeding_recvs) + + node_to_fed_sends: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} + for node in rev_graph: + node_to_fed_sends.setdefault(node, set()) + if isinstance(node, DistributedSend): + tag_user_nodes(rev_graph, tag=node, starting_point=node, + node_to_tags=node_to_fed_sends) + + def get_part_id(expr: ArrayOrNames) -> DistributedPartitionId: + return DistributedPartitionId(frozenset(node_to_fed_sends[expr]), + frozenset(node_to_feeding_recvs[expr])) + + # {{{ Sanity checks + + if __debug__: + for node, _ in node_to_feeding_recvs.items(): + for n in node_to_feeding_recvs[node]: + assert(isinstance(n, DistributedRecv)) + + for node, _ in node_to_fed_sends.items(): + for n in node_to_fed_sends[node]: + assert(isinstance(n, DistributedSend)) + + tm = TopoSortMapper() + tm(outputs) + + for node in tm.topological_order: + get_part_id(node) + + # }}} + + from pytato.partition import find_partition + return cast(DistributedGraphPartition, + find_partition(outputs, get_part_id, _DistributedGraphPartitioner)) + +# }}} + + +# {{{ construct tag numbering + +def number_distributed_tags( + mpi_communicator: Any, + partition: DistributedGraphPartition, + base_tag: int) -> Tuple[DistributedGraphPartition, int]: + """Return a new :class:`~pytato.distributed.DistributedGraphPartition` + in which symbolic tags are replaced by unique integer tags, created by + counting upward from *base_tag*. + + :returns: a tuple ``(partition, next_tag)``, where *partition* is the new + :class:`~pytato.distributed.DistributedGraphPartition` with numerical + tags, and *next_tag* is the lowest integer tag above *base_tag* that + was not used. + + .. note:: + + This is a potentially heavyweight MPI-collective operation on + *mpi_communicator*. + """ + tags = frozenset({ + recv.comm_tag + for part in partition.parts.values() + for name, recv in part.input_name_to_recv_node.items() + } | { + send.comm_tag + for part in partition.parts.values() + for name, send in part.output_name_to_send_node.items()}) + + from mpi4py import MPI + + def set_union( + set_a: FrozenSet[Any], set_b: FrozenSet[Any], + mpi_data_type: MPI.Datatype) -> FrozenSet[str]: + assert mpi_data_type is None + assert isinstance(set_a, frozenset) + assert isinstance(set_b, frozenset) + + return set_a | set_b + + root_rank = 0 + + set_union_mpi_op = MPI.Op.Create( + # type ignore reason: mpi4py misdeclares op functions as returning + # None. + set_union, # type: ignore[arg-type] + commute=True) + try: + all_tags = mpi_communicator.reduce( + tags, set_union_mpi_op, root=root_rank) + finally: + set_union_mpi_op.Free() + + if mpi_communicator.rank == root_rank: + sym_tag_to_int_tag = {} + next_tag = base_tag + for sym_tag in all_tags: + sym_tag_to_int_tag[sym_tag] = next_tag + next_tag += 1 + + mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank) + else: + sym_tag_to_int_tag, next_tag = mpi_communicator.bcast(None, root=root_rank) + + from dataclasses import replace + return DistributedGraphPartition( + parts={ + pid: replace(part, + input_name_to_recv_node={ + name: recv.copy(comm_tag=sym_tag_to_int_tag[recv.comm_tag]) + for name, recv in part.input_name_to_recv_node.items()}, + output_name_to_send_node={ + name: send.copy(comm_tag=sym_tag_to_int_tag[send.comm_tag]) + for name, send in part.output_name_to_send_node.items()}, + ) + for pid, part in partition.parts.items() + }, + var_name_to_result=partition.var_name_to_result, + toposorted_part_ids=partition.toposorted_part_ids), next_tag + +# }}} + + +# {{{ distributed execute + +def _post_receive(mpi_communicator: Any, + recv: DistributedRecv) -> Tuple[Any, np.ndarray[Any, Any]]: + # FIXME: recv.shape might be parametric, evaluate + buf = np.empty(recv.shape, dtype=recv.dtype) + + return mpi_communicator.Irecv( + buf=buf, source=recv.src_rank, tag=recv.comm_tag), buf + + +def _mpi_send(mpi_communicator: Any, send_node: DistributedSend, + data: np.ndarray[Any, Any]) -> Any: + # Must use-non-blocking send, as blocking send may wait for a corresponding + # receive to be posted (but if sending to self, this may only occur later). + return mpi_communicator.Isend( + data, dest=send_node.dest_rank, tag=send_node.comm_tag) + + +def execute_distributed_partition( + partition: DistributedGraphPartition, prg_per_partition: + Dict[Hashable, BoundProgram], + queue: Any, mpi_communicator: Any, + input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + + if input_args is None: + input_args = {} + + from mpi4py import MPI + + if len(partition.parts) != 1: + recv_names_tup, recv_requests_tup, recv_buffers_tup = zip(*[ + (name,) + _post_receive(mpi_communicator, recv) + for part in partition.parts.values() + for name, recv in part.input_name_to_recv_node.items()]) + recv_names = list(recv_names_tup) + recv_requests = list(recv_requests_tup) + recv_buffers = list(recv_buffers_tup) + del recv_names_tup + del recv_requests_tup + del recv_buffers_tup + else: + # Only a single partition, no recv requests exist + recv_names = [] + recv_requests = [] + recv_buffers = [] + + context: Dict[str, Any] = input_args.copy() + + pids_to_execute = set(partition.parts) + pids_executed = set() + recv_names_completed = set() + send_requests = [] + + def exec_ready_part(part: DistributedGraphPart) -> None: + inputs = {k: context[k] for k in part.all_input_names()} + + _evt, result_dict = prg_per_partition[part.pid](queue, **inputs) + + context.update(result_dict) + + for name, send_node in part.output_name_to_send_node.items(): + # FIXME: pytato shouldn't depend on pyopencl + if isinstance(context[name], np.ndarray): + data = context[name] + else: + data = context[name].get(queue) + send_requests.append(_mpi_send(mpi_communicator, send_node, data)) + + pids_executed.add(part.pid) + pids_to_execute.remove(part.pid) + + def wait_for_some_recvs() -> None: + complete_recv_indices = MPI.Request.Waitsome(recv_requests) + + # Waitsome is allowed to return None + if not complete_recv_indices: + complete_recv_indices = [] + + # reverse to preserve indices + for idx in sorted(complete_recv_indices, reverse=True): + name = recv_names.pop(idx) + recv_requests.pop(idx) + buf = recv_buffers.pop(idx) + + # FIXME: pytato shouldn't depend on pyopencl + import pyopencl as cl + context[name] = cl.array.to_device(queue, buf) + recv_names_completed.add(name) + + # FIXME: This keeps all variables alive that are used to get data into + # and out of partitions. Probably not what we want long-term. + + # {{{ main loop + + while pids_to_execute: + ready_pids = {pid + for pid in pids_to_execute + # FIXME: Only O(n**2) altogether. Nobody is going to notice, right? + if partition.parts[pid].needed_pids <= pids_executed + and (set(partition.parts[pid].input_name_to_recv_node) + <= recv_names_completed)} + for pid in ready_pids: + exec_ready_part(partition.parts[pid]) + + if not ready_pids: + wait_for_some_recvs() + + # }}} + + for send_req in send_requests: + send_req.Wait() + + return context + +# }}} + +# vim: foldmethod=marker diff --git a/pytato/equality.py b/pytato/equality.py index 0bfdb62..984d1f7 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -34,6 +34,7 @@ from pytato.array import (AdvancedIndexInContiguousAxes, if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult + from pytato.distributed import DistributedRecv, DistributedSendRefHolder __doc__ = """ .. autoclass:: EqualityComparer @@ -245,6 +246,26 @@ class EqualityComparer: and all(self.rec(expr1._data[name], expr2._data[name]) for name in expr1._data)) + def map_distributed_send_ref_holder( + self, expr1: DistributedSendRefHolder, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.rec(expr1.send.data, expr2.send.data) + and self.rec(expr1.passthrough_data, expr2.passthrough_data) + and expr1.send.dest_rank == expr2.send.dest_rank + and expr1.send.comm_tag == expr2.send.comm_tag + and expr1.send.tags == expr2.send.tags + and expr1.tags == expr2.tags + ) + + def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.src_rank == expr2.src_rank + and expr1.comm_tag == expr2.comm_tag + and expr1.shape == expr2.shape + and expr1.dtype == expr2.dtype + and expr1.tags == expr2.tags + ) + # }}} # vim: fdm=marker diff --git a/pytato/partition.py b/pytato/partition.py index 72012a5..223dc7e 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -25,10 +25,11 @@ THE SOFTWARE. """ from typing import (Any, Callable, Dict, Union, Set, List, Hashable, Tuple, TypeVar, - FrozenSet, Mapping) + FrozenSet, Mapping, Optional, Type) from dataclasses import dataclass +from pytools import memoize_method from pytato.transform import EdgeCachedMapper, CachedWalkMapper from pytato.array import ( Array, AbstractResultWithNamedArrays, Placeholder, @@ -40,10 +41,19 @@ from pytato.target import BoundProgram __doc__ = """ .. autoclass:: GraphPart .. autoclass:: GraphPartition +.. autoclass:: GraphPartitioner .. autoexception:: PartitionInducedCycleError .. autofunction:: find_partition .. autofunction:: execute_partition + + +Internal stuff that is only here because the documentation tool wants it +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: T + + A type variable for :class:`~pytato.array.AbstractResultWithNamedArrays`. """ @@ -54,10 +64,14 @@ PartId = Hashable # {{{ graph partitioner -class _GraphPartitioner(EdgeCachedMapper): +class GraphPartitioner(EdgeCachedMapper): """Given a function *get_part_id*, produces subgraphs representing the computation. Users should not use this class directly, but use :meth:`find_partition` instead. + + .. automethod:: __init__ + .. automethod:: __call__ + .. automethod:: make_partition """ def __init__(self, get_part_id: Callable[[ArrayOrNames], PartId]) -> None: @@ -87,6 +101,8 @@ class _GraphPartitioner(EdgeCachedMapper): # miss any of them. self.seen_part_ids: Set[PartId] = set() + self.pid_to_user_input_names: Dict[PartId, Set[str]] = {} + def get_part_id(self, expr: ArrayOrNames) -> PartId: part_id = self._get_part_id(expr) self.seen_part_ids.add(part_id) @@ -146,6 +162,62 @@ class _GraphPartitioner(EdgeCachedMapper): return super().__call__(expr, *args, **kwargs) + def make_partition(self, outputs: DictOfNamedArrays) -> GraphPartition: + rewritten_outputs = { + name: self(expr) for name, expr in outputs._data.items()} + + pid_to_output_names: Dict[PartId, Set[str]] = { + pid: set() for pid in self.seen_part_ids} + pid_to_input_names: Dict[PartId, Set[str]] = { + pid: set() for pid in self.seen_part_ids} + + var_name_to_result = self.var_name_to_result.copy() + + for out_name, rewritten_output in rewritten_outputs.items(): + out_part_id = self._get_part_id(outputs._data[out_name]) + pid_to_output_names.setdefault(out_part_id, set()).add(out_name) + var_name_to_result[out_name] = rewritten_output + + # Mapping of nodes to their successors; used to compute the topological order + pid_to_needing_pids: Dict[PartId, Set[PartId]] = { + pid: set() for pid in self.seen_part_ids} + pid_to_needed_pids: Dict[PartId, Set[PartId]] = { + pid: set() for pid in self.seen_part_ids} + + for (pid_target, pid_dependency), var_names in \ + self.part_pair_to_edges.items(): + pid_to_needing_pids[pid_dependency].add(pid_target) + pid_to_needed_pids[pid_target].add(pid_dependency) + + for var_name in var_names: + pid_to_output_names[pid_dependency].add(var_name) + pid_to_input_names[pid_target].add(var_name) + + from pytools.graph import compute_topological_order, CycleError + try: + toposorted_part_ids = compute_topological_order(pid_to_needing_pids) + except CycleError: + raise PartitionInducedCycleError + + return GraphPartition( + parts={ + pid: GraphPart( + pid=pid, + needed_pids=frozenset(pid_to_needed_pids[pid]), + user_input_names=frozenset( + self.pid_to_user_input_names.get(pid, set())), + partition_input_names=frozenset(pid_to_input_names[pid]), + output_names=frozenset(pid_to_output_names[pid]), + ) + for pid in self.seen_part_ids}, + var_name_to_result=var_name_to_result, + toposorted_part_ids=toposorted_part_ids) + + def map_placeholder(self, expr: Placeholder, *args: Any) -> Any: + pid = self.get_part_id(expr) + self.pid_to_user_input_names.setdefault(pid, set()).add(expr.name) + return super().map_placeholder(expr) + # }}} @@ -163,19 +235,38 @@ class GraphPart: The IDs of parts that are required to be evaluated before this part can be evaluated. - .. attribute:: input_names + .. attribute:: user_input_names + + A :class:`dict` mapping names to :class:`~pytato.array.Placeholder` + instances that represent input to the computational graph, i.e. were + *not* introduced by partitioning. - Names of placeholders the part requires as input. + .. attribute:: partition_input_names + + Names of placeholders the part requires as input from other parts + in the partition. .. attribute:: output_names Names of placeholders this part provides as output. + + .. attribute:: distributed_sends + + List of :class:`~pytato.distributed.DistributedSend` instances whose + data are in this part. + + .. automethod:: all_input_names """ pid: PartId needed_pids: FrozenSet[PartId] - input_names: FrozenSet[str] + user_input_names: FrozenSet[str] + partition_input_names: FrozenSet[str] output_names: FrozenSet[str] + @memoize_method + def all_input_names(self) -> FrozenSet[str]: + return self.user_input_names | self. partition_input_names + @dataclass(frozen=True) class GraphPartition: @@ -214,10 +305,11 @@ class PartitionInducedCycleError(Exception): """ -# {{{ find_partitions +# {{{ find_partition def find_partition(outputs: DictOfNamedArrays, - part_func: Callable[[ArrayOrNames], PartId]) ->\ + part_func: Callable[[ArrayOrNames], PartId], + partitioner_class: Type[GraphPartitioner] = GraphPartitioner) ->\ GraphPartition: """Partitions the *expr* according to *part_func* and generates code for each partition. Raises :exc:`PartitionInducedCycleError` if the partitioning @@ -235,58 +327,15 @@ def find_partition(outputs: DictOfNamedArrays, where ``A`` and ``C`` are in partition 1, and ``B`` is in partition 2. - :param expr: The expression to partition. + :param outputs: The outputs to partition. :param part_func: A callable that returns an instance of :class:`Hashable` for a node. + :param partitioner_class: A :class:`GraphPartitioner` to + guide the partitioning. :returns: An instance of :class:`GraphPartition` that contains the partition. """ - gp = _GraphPartitioner(part_func) - rewritten_outputs = {name: gp(expr) for name, expr in outputs._data.items()} - - pid_to_output_names: Dict[PartId, Set[str]] = { - pid: set() for pid in gp.seen_part_ids} - pid_to_input_names: Dict[PartId, Set[str]] = { - pid: set() for pid in gp.seen_part_ids} - - var_name_to_result = gp.var_name_to_result.copy() - - for out_name, rewritten_output in rewritten_outputs.items(): - out_part_id = part_func(outputs._data[out_name]) - pid_to_output_names.setdefault(out_part_id, set()).add(out_name) - var_name_to_result[out_name] = rewritten_output - - # Mapping of nodes to their successors; used to compute the topological order - pid_to_needing_pids: Dict[PartId, Set[PartId]] = { - pid: set() for pid in gp.seen_part_ids} - pid_to_needed_pids: Dict[PartId, Set[PartId]] = { - pid: set() for pid in gp.seen_part_ids} - - for (pid_target, pid_dependency), var_names in \ - gp.part_pair_to_edges.items(): - pid_to_needing_pids[pid_dependency].add(pid_target) - pid_to_needed_pids[pid_target].add(pid_dependency) - - for var_name in var_names: - pid_to_output_names[pid_dependency].add(var_name) - pid_to_input_names[pid_target].add(var_name) - - from pytools.graph import compute_topological_order, CycleError - try: - toposorted_part_ids = compute_topological_order(pid_to_needing_pids) - except CycleError: - raise PartitionInducedCycleError - - result = GraphPartition( - parts={ - pid: GraphPart( - pid=pid, - needed_pids=frozenset(pid_to_needed_pids[pid]), - input_names=frozenset(pid_to_input_names[pid]), - output_names=frozenset(pid_to_output_names[pid])) - for pid in gp.seen_part_ids}, - var_name_to_result=var_name_to_result, - toposorted_part_ids=toposorted_part_ids) + result = partitioner_class(part_func).make_partition(outputs) if __debug__: _check_partition_disjointness(result) @@ -320,12 +369,14 @@ def _check_partition_disjointness(partition: GraphPartition) -> None: # Placeholders represent values computed in one partition # and used in one or more other ones. As a result, the # same placeholder may occur in more than one partition. - assert ( - isinstance(my_node, Placeholder) - or my_node not in other_node_set), ( - "partitions not disjoint: " - f"{my_node.__class__.__name__} (id={id(my_node)}) " - f"in both '{part.pid}' and '{other_part_id}'") + if not(isinstance(my_node, Placeholder) + or my_node not in other_node_set): + raise RuntimeError( + "Partitions not disjoint: " + f"{my_node.__class__.__name__} (id={hex(id(my_node))}) " + f"in both '{part.pid}' and '{other_part_id}'" + f"{part.output_names=} " + f"{partition.parts[other_part_id].output_names=} ") part_id_to_nodes[part.pid] = mapper.seen_nodes @@ -339,16 +390,16 @@ def generate_code_for_partition(partition: GraphPartition) \ """Return a mapping of partition identifiers to their :class:`pytato.target.BoundProgram`.""" from pytato import generate_loopy - prg_per_partition = {} + part_id_to_prg = {} for part in partition.parts.values(): d = DictOfNamedArrays( {var_name: partition.var_name_to_result[var_name] for var_name in part.output_names }) - prg_per_partition[part.pid] = generate_loopy(d) + part_id_to_prg[part.pid] = generate_loopy(d) - return prg_per_partition + return part_id_to_prg # }}} @@ -356,7 +407,8 @@ def generate_code_for_partition(partition: GraphPartition) \ # {{{ execute_partitions def execute_partition(partition: GraphPartition, prg_per_partition: - Dict[PartId, BoundProgram], queue: Any) -> Dict[str, Any]: + Dict[PartId, BoundProgram], queue: Any, + input_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """Executes a set of partitions on a :class:`pyopencl.CommandQueue`. :param parts: An instance of :class:`GraphPartition` representing the @@ -365,11 +417,15 @@ def execute_partition(partition: GraphPartition, prg_per_partition: code on. :returns: A dictionary of variable names mapped to their values. """ - context: Dict[str, Any] = {} + if input_args is None: + input_args = {} + + context: Dict[str, Any] = input_args.copy() + for pid in partition.toposorted_part_ids: part = partition.parts[pid] inputs = { - k: context[k] for k in part.input_names + k: context[k] for k in part.all_input_names() if k in context} _evt, result_dict = prg_per_partition[pid](queue=queue, **inputs) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index e3f76b0..b2fd5bf 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -126,6 +126,8 @@ class Reprifier(Mapper): map_non_contiguous_advanced_index = _map_generic_array map_reshape = _map_generic_array map_einsum = _map_generic_array + map_distributed_recv = _map_generic_array + map_distributed_send_ref_holder = _map_generic_array def map_data_wrapper(self, expr: DataWrapper, depth: int) -> str: if depth > self.truncation_depth: diff --git a/pytato/transform.py b/pytato/transform.py index 863e927..bb531d1 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -2,6 +2,8 @@ from __future__ import annotations __copyright__ = """ Copyright (C) 2020 Matt Wala +Copyright (C) 2020-21 Kaushik Kulkarni +Copyright (C) 2020-21 University of Illinois Board of Trustees """ __license__ = """ @@ -24,10 +26,9 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, - List, Mapping, Iterable, Optional, Tuple) -from pyrsistent.typing import PMap as PMapT + List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING) from pytato.array import ( Array, IndexLambda, Placeholder, Stack, Roll, @@ -36,10 +37,13 @@ from pytato.array import ( IndexRemappingBase, Einsum, InputArgumentBase, BasicIndex, AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, IndexBase) + from pytato.loopy import LoopyCall, LoopyCallResult from dataclasses import dataclass from pytato.tags import ImplStored -from pyrsistent import pmap + +if TYPE_CHECKING: + from pytato.distributed import DistributedSendRefHolder, DistributedRecv T = TypeVar("T", Array, AbstractResultWithNamedArrays) CombineT = TypeVar("CombineT") # used in CombineMapper @@ -74,6 +78,7 @@ Dict representation of DAGs .. autoclass:: UsersCollector .. autofunction:: reverse_graph +.. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes Internal stuff that is only here because the documentation tool wants it @@ -306,6 +311,23 @@ class CopyMapper(CachedMapper[ArrayOrNames]): axes=expr.axes, tags=expr.tags) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> Array: + from pytato.distributed import DistributedSend, DistributedSendRefHolder + return DistributedSendRefHolder( + DistributedSend( + data=self.rec(expr.send.data), + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag), + self.rec(expr.passthrough_data)) + + def map_distributed_recv(self, expr: DistributedRecv) -> Array: + from pytato.distributed import DistributedRecv + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=self.rec_idx_or_size_tuple(expr.shape), + dtype=expr.dtype, tags=expr.tags) + # }}} @@ -400,6 +422,16 @@ class CombineMapper(Mapper, Generic[CombineT]): def map_loopy_call_result(self, expr: LoopyCallResult) -> CombineT: return self.rec(expr._container) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> CombineT: + return self.combine( + self.rec(expr.send.data), + self.rec(expr.passthrough_data), + ) + + def map_distributed_recv(self, expr: DistributedRecv) -> CombineT: + return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + # }}} @@ -459,6 +491,14 @@ class DependencyMapper(CombineMapper[R]): def map_loopy_call_result(self, expr: LoopyCallResult) -> R: return self.combine(frozenset([expr]), super().map_loopy_call_result(expr)) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> R: + return self.combine( + frozenset([expr]), super().map_distributed_send_ref_holder(expr)) + + def map_distributed_recv(self, expr: DistributedRecv) -> R: + return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) + # }}} @@ -648,6 +688,24 @@ class WalkMapper(Mapper): self.post_visit(expr) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> None: + if not self.visit(expr): + return + + self.rec(expr.send.data) + self.rec(expr.passthrough_data) + + self.post_visit(expr) + + def map_distributed_recv(self, expr: DistributedRecv, *args: Any) -> None: + if not self.visit(expr): + return + + self.rec_idx_or_size_tuple(expr.shape) + + self.post_visit(expr) + def map_named_array(self, expr: NamedArray) -> None: if not self.visit(expr): return @@ -1031,7 +1089,9 @@ class UsersCollector(CachedMapper[ArrayOrNames]): def __init__(self) -> None: super().__init__() - self.node_to_users: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} + from pytato.distributed import DistributedSend + self.node_to_users: Dict[ArrayOrNames, + Set[Union[DistributedSend, ArrayOrNames]]] = {} def __call__(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> Any: # Root node has no predecessor @@ -1129,16 +1189,25 @@ class UsersCollector(CachedMapper[ArrayOrNames]): self.node_to_users.setdefault(child, set()).add(expr) self.rec(child) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> None: + self.node_to_users.setdefault(expr.passthrough_data, set()).add(expr) + self.rec(expr.passthrough_data) + self.node_to_users.setdefault(expr.send.data, set()).add(expr.send) + self.rec(expr.send.data) + + def map_distributed_recv(self, expr: DistributedRecv, *args: Any) -> None: + self.rec_idx_or_size_tuple(expr, expr.shape) + -def get_users(expr: ArrayOrNames) -> PMapT[ArrayOrNames, - FrozenSet[ArrayOrNames]]: +def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, + Set[ArrayOrNames]]: """ Returns a mapping from node in *expr* to its direct users. """ user_collector = UsersCollector() user_collector(expr) - return pmap({node: frozenset(users) - for node, users in user_collector.node_to_users.items()}) + return user_collector.node_to_users # type: ignore # }}} @@ -1168,7 +1237,7 @@ def reverse_graph(graph: Dict[ArrayOrNames, Set[ArrayOrNames]]) \ def _recursively_get_all_users( - direct_users: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]], + direct_users: Mapping[ArrayOrNames, Set[ArrayOrNames]], node: ArrayOrNames) -> FrozenSet[ArrayOrNames]: result = set() queue = list(direct_users.get(node, set())) @@ -1201,8 +1270,8 @@ def rec_get_user_nodes(expr: ArrayOrNames, return _recursively_get_all_users(users, node) -def tag_child_nodes( - graph: Mapping[ArrayOrNames, FrozenSet[ArrayOrNames]], +def tag_user_nodes( + graph: Mapping[ArrayOrNames, Set[ArrayOrNames]], tag: Any, starting_point: ArrayOrNames, node_to_tags: Optional[Dict[ArrayOrNames, Set[ArrayOrNames]]] = None @@ -1214,12 +1283,11 @@ def tag_child_nodes( use case for this function is the graph in :attr:`UsersCollector.node_to_users`. :param tag: The value to tag the nodes with. - :param starting_point: An optional starting point in *graph*. + :param starting_point: A starting point in *graph*. :param node_to_tags: The resulting mapping of nodes to tags. - :returns: the updated value of *node_to_tags*. """ from warnings import warn - warn("tag_child_nodes is set for deprecation in June, 2022", + warn("tag_user_nodes is set for deprecation in June, 2022", DeprecationWarning) if node_to_tags is None: @@ -1237,7 +1305,7 @@ def tag_child_nodes( # {{{ EdgeCachedMapper -class EdgeCachedMapper(CachedMapper[ArrayOrNames], ABC): +class EdgeCachedMapper(CachedMapper[ArrayOrNames]): """ Mapper class to execute a rewriting method (:meth:`handle_edge`) on each edge in the graph. @@ -1384,6 +1452,24 @@ class EdgeCachedMapper(CachedMapper[ArrayOrNames], ABC): for name, child in expr.bindings.items()}, ) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, *args: Any) -> \ + DistributedSendRefHolder: + from pytato.distributed import DistributedSendRefHolder + return DistributedSendRefHolder( + send=self.handle_edge(expr, expr.send.data), + passthrough_data=self.handle_edge(expr, expr.passthrough_data), + tags=expr.tags + ) + + def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \ + -> Any: + from pytato.distributed import DistributedRecv + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=self.rec_idx_or_size_tuple(expr, expr.shape, *args), + dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + # }}} # }}} diff --git a/pytato/visualization.py b/pytato/visualization.py index cb7c48c..d139586 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -29,7 +29,9 @@ THE SOFTWARE. import contextlib import dataclasses import html -from typing import Callable, Dict, Union, Iterator, List, Mapping, Hashable + +from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List, + Mapping, Hashable) from pytools import UniqueNameGenerator from pytools.codegen import CodeGenerator as CodeGeneratorBase @@ -44,6 +46,10 @@ from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames from pytato.partition import GraphPartition +from pytato.distributed import DistributedGraphPart + +if TYPE_CHECKING: + from pytato.distributed import DistributedSendRefHolder __doc__ = """ @@ -80,7 +86,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -class ArrayToDotNodeInfoMapper(CachedMapper[Array]): +class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): def __init__(self) -> None: super().__init__() self.nodes: Dict[ArrayOrNames, DotNodeInfo] = {} @@ -111,8 +117,7 @@ class ArrayToDotNodeInfoMapper(CachedMapper[Array]): info.edges[field] = attr elif isinstance(attr, AbstractResultWithNamedArrays): - # type-ignore-reason: incompatible with superclass - self.rec(attr) # type: ignore[arg-type] + self.rec(attr) info.edges[field] = attr elif isinstance(attr, tuple): @@ -189,6 +194,19 @@ class ArrayToDotNodeInfoMapper(CachedMapper[Array]): entrypoint=expr.entrypoint), edges=edges) + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder) -> None: + + info = self.get_common_dot_info(expr) + + self.rec(expr.passthrough_data) + info.edges["passthrough"] = expr.passthrough_data + + self.rec(expr.send.data) + info.edges["sent"] = expr.send.data + + self.nodes[expr] = info + def dot_escape(s: str) -> str: # "\" and HTML are significant in graphviz. @@ -205,22 +223,22 @@ class DotEmitter(CodeGeneratorBase): self("}") -def _emit_array(emit: DotEmitter, info: DotNodeInfo, id: str, - color: str = "white") -> None: +def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str], + dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' rows = ['%s' - % (td_attrib, dot_escape(info.title))] + % (td_attrib, dot_escape(title))] - for name, field in info.fields.items(): + for name, field in fields.items(): field_content = dot_escape(field).replace("\n", "
") rows.append( f"{dot_escape(name)}:" f"{field_content}" ) table = "\n%s
" % (table_attrib, "".join(rows)) - emit("%s [label=<%s> style=filled fillcolor=%s]" % (id, table, color)) + emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color)) def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames], @@ -282,11 +300,13 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: with emit.block("subgraph cluster_Inputs"): emit('label="Inputs"') for array in input_arrays: - _emit_array(emit, nodes[array], array_to_id[array]) + _emit_array(emit, + nodes[array].title, nodes[array].fields, array_to_id[array]) # Emit non-inputs. for array in internal_arrays: - _emit_array(emit, nodes[array], array_to_id[array]) + _emit_array(emit, + nodes[array].title, nodes[array].fields, array_to_id[array]) # Emit edges. for array, node in nodes.items(): @@ -334,6 +354,26 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: # Second pass: emit the graph. for part in partition.parts.values(): + # {{{ emit receives nodes if distributed + + if isinstance(part, DistributedGraphPart): + part_dist_recv_var_name_to_node_id = {} + for name, recv in ( + part.input_name_to_recv_node.items()): + node_id = id_gen("recv") + _emit_array(emit, "Recv", { + "shape": stringify_shape(recv.shape), + "dtype": str(recv.dtype), + "src_rank": str(recv.src_rank), + "comm_tag": str(recv.comm_tag), + }, node_id) + + part_dist_recv_var_name_to_node_id[name] = node_id + else: + part_dist_recv_var_name_to_node_id = {} + + # }}} + part_node_to_info = part_id_to_node_info[part.pid] input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] @@ -355,13 +395,27 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: # Non-Placeholders are emitted *inside* their subgraphs below. if isinstance(array, Placeholder): if array not in emitted_placeholders: - _emit_array(emit, part_node_to_info[array], + _emit_array(emit, + part_node_to_info[array].title, + part_node_to_info[array].fields, array_to_id[array], "deepskyblue") - emitted_placeholders.add(array) # Emit cross-partition edges - tgt = array_to_id[partition.var_name_to_result[array.name]] - emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") + if array.name in part_dist_recv_var_name_to_node_id: + tgt = part_dist_recv_var_name_to_node_id[array.name] + emit(f"{tgt} -> {array_to_id[array]} [style=dotted]") + emitted_placeholders.add(array) + elif array.name in part.user_input_names: + # These are placeholders for external input. They + # are cleanly associated with a single partition + # and thus emitted below. + pass + else: + # placeholder for a value from a different partition + tgt = array_to_id[ + partition.var_name_to_result[array.name]] + emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") + emitted_placeholders.add(array) # }}} @@ -370,13 +424,42 @@ def get_dot_graph_from_partition(partition: GraphPartition) -> str: emit(f'label="{part.pid}"') for array in input_arrays: - if not isinstance(array, Placeholder): - _emit_array(emit, part_node_to_info[array], + if (not isinstance(array, Placeholder) + or array.name in part.user_input_names): + _emit_array(emit, + part_node_to_info[array].title, + part_node_to_info[array].fields, array_to_id[array], "deepskyblue") # Emit internal nodes for array in internal_arrays: - _emit_array(emit, part_node_to_info[array], array_to_id[array]) + _emit_array(emit, + part_node_to_info[array].title, + part_node_to_info[array].fields, + array_to_id[array]) + + # {{{ emit send nodes if distributed + + deferred_send_edges = [] + if isinstance(part, DistributedGraphPart): + for name, send in ( + part.output_name_to_send_node.items()): + node_id = id_gen("send") + _emit_array(emit, "Send", { + "dest_rank": str(send.dest_rank), + "comm_tag": str(send.comm_tag), + }, node_id) + + deferred_send_edges.append( + f"{array_to_id[send.data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') + + # }}} + + # If an edge is emitted in a subgraph, it drags its nodes into the + # subgraph, too. Not what we want. + for edge in deferred_send_edges: + emit(edge) # Emit intra-partition edges for array, node in part_node_to_info.items(): diff --git a/test/test_distributed.py b/test/test_distributed.py new file mode 100644 index 0000000..3ac44ad --- /dev/null +++ b/test/test_distributed.py @@ -0,0 +1,221 @@ +__copyright__ = """Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import pyopencl as cl +import numpy as np +import pytato as pt +import sys +import os + +from pytato.distributed import (staple_distributed_send, make_distributed_recv, + find_distributed_partition, + execute_distributed_partition, number_distributed_tags) + +from pytato.partition import generate_code_for_partition + + +# {{{ mpi test infrastructure + +def run_test_with_mpi(num_ranks, f, *args): + import pytest + pytest.importorskip("mpi4py") + + from pickle import dumps + from base64 import b64encode + + invocation_info = b64encode(dumps((f, args))).decode() + from subprocess import check_call + + # NOTE: CI uses OpenMPI; -x to pass env vars. MPICH uses -env + check_call([ + "mpiexec", "-np", str(num_ranks), + "-x", "RUN_WITHIN_MPI=1", + "-x", f"INVOCATION_INFO={invocation_info}", + sys.executable, __file__]) + + +def run_test_with_mpi_inner(): + from pickle import loads + from base64 import b64decode + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(cl.create_some_context, *args) + +# }}} + + +# {{{ "basic" test (similar to distributed example) + +def test_distributed_execution_basic(): + run_test_with_mpi(2, _do_test_distributed_execution_basic) + + +def _do_test_distributed_execution_basic(ctx_factory): + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + + rank = comm.Get_rank() + size = comm.Get_size() + + rng = np.random.default_rng(seed=27) + + x_in = rng.integers(100, size=(4, 4)) + x = pt.make_data_wrapper(x_in) + + halo = staple_distributed_send(x, dest_rank=(rank-1) % size, comm_tag=42, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=42, shape=(4, 4), dtype=int)) + + y = x+halo + + # Find the partition + outputs = pt.DictOfNamedArrays({"out": y}) + distributed_parts = find_distributed_partition(outputs) + prg_per_partition = generate_code_for_partition(distributed_parts) + + # Execute the distributed partition + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + + context = execute_distributed_partition(distributed_parts, prg_per_partition, + queue, comm) + + final_res = context["out"].get(queue) + + # All ranks generate the same random numbers (same seed). + np.testing.assert_allclose(x_in*2, final_res) + +# }}} + + +# {{{ test based on random dag + +def test_distributed_execution_random_dag(): + run_test_with_mpi(2, _do_test_distributed_execution_random_dag) + + +class _RandomDAGTag: + pass + + +def _do_test_distributed_execution_random_dag(ctx_factory): + from mpi4py import MPI # pylint: disable=import-error + comm = MPI.COMM_WORLD + + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + rank = comm.Get_rank() + size = comm.Get_size() + + from testlib import RandomDAGContext, make_random_dag + + axis_len = 4 + comm_fake_prob = 500 + + gen_comm_called = False + + ntests = 10 + for i in range(ntests): + seed = 120 + i + print(f"Step {i} {seed}") + + # {{{ compute value with communication + + comm_tag = 17 + + def gen_comm(rdagc): + nonlocal gen_comm_called + gen_comm_called = True + + nonlocal comm_tag + comm_tag += 1 + tag = (comm_tag, _RandomDAGTag) + + inner = make_random_dag(rdagc) + return staple_distributed_send( + inner, dest_rank=(rank-1) % size, comm_tag=tag, + stapled_to=make_distributed_recv( + src_rank=(rank+1) % size, comm_tag=tag, + shape=inner.shape, dtype=inner.dtype)) + + rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False, + additional_generators=[ + (comm_fake_prob, gen_comm) + ]) + x_comm = make_random_dag(rdagc_comm) + + distributed_partition = find_distributed_partition( + pt.DictOfNamedArrays({"result": x_comm})) + + # Transform symbolic tags into numeric ones for MPI + distributed_partition, _new_mpi_base_tag = number_distributed_tags( + comm, + distributed_partition, + base_tag=comm_tag) + + prg_per_partition = generate_code_for_partition(distributed_partition) + + context = execute_distributed_partition( + distributed_partition, prg_per_partition, queue, comm) + + res_comm = context["result"] + + # }}} + + # {{{ compute ref value without communication + + # compiled evaluation (i.e. use_numpy=False) fails for some of these + # graphs, cf. https://github.com/inducer/pytato/pull/255 + rdagc_no_comm = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=True, + additional_generators=[ + (comm_fake_prob, lambda rdagc: make_random_dag(rdagc)) + ]) + res_no_comm_numpy = make_random_dag(rdagc_no_comm) + + # }}} + + if not isinstance(res_comm, np.ndarray): + res_comm = res_comm.get(queue=queue) + + np.testing.assert_allclose(res_comm, res_no_comm_numpy) + + assert gen_comm_called + +# }}} + + +if __name__ == "__main__": + if "RUN_WITHIN_MPI" in os.environ: + run_test_with_mpi_inner() + elif len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: foldmethod=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 43b8188..f85bff9 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -638,7 +638,7 @@ def test_rec_get_user_nodes_linear_complexity(): assert (expected_result == result) -def test_tag_child_nodes_linear_complexity(): +def test_tag_user_nodes_linear_complexity(): from numpy.random import default_rng def construct_intestine_graph(depth=100, seed=0): @@ -663,7 +663,7 @@ def test_tag_child_nodes_linear_complexity(): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() - result = pt.transform.tag_child_nodes(user_collector.node_to_users, "foo", inp) + result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) assert expected_result == result -- GitLab