diff --git a/pytato/__init__.py b/pytato/__init__.py index aaa237d6b118de352996fcbd451660c5053f450e..b80a55e28af8b09d63fc534b0ae47fcb1fb23256 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -75,7 +75,8 @@ from pytato.distributed import (make_distributed_send, make_distributed_recv, staple_distributed_send, find_distributed_partition, number_distributed_tags, - execute_distributed_partition) + execute_distributed_partition, + ) from pytato.partition import generate_code_for_partition diff --git a/pytato/distributed.py b/pytato/distributed.py index 24db4faed8847e4df40a5d8839f4045dc87c5971..790dfea982a8c09efc3306cf7e7c09d44ef27445 100644 --- a/pytato/distributed.py +++ b/pytato/distributed.py @@ -25,19 +25,28 @@ THE SOFTWARE. """ from typing import (Any, Dict, Hashable, Tuple, Optional, Set, # noqa: F401 - List, FrozenSet, Callable, cast, Mapping) # Mapping required by sphinx + List, FrozenSet, Callable, cast, Mapping, Iterable + ) # Mapping required by sphinx +from pyrsistent.typing import PMap as PMapT from dataclasses import dataclass from pytools import UniqueNameGenerator -from pytools.tag import Taggable, TagsType +from pytools.tag import Taggable, TagsType, UniqueTag from pytato.array import (Array, _SuppliedShapeAndDtypeMixin, - DictOfNamedArrays, ShapeType, - Placeholder, make_placeholder, - _get_default_axes, AxesT) -from pytato.transform import ArrayOrNames, CopyMapper + DictOfNamedArrays, ShapeType, Placeholder, + make_placeholder, _get_default_axes, AxesT, + NamedArray) +from pytato.transform import (ArrayOrNames, CopyMapper, Mapper, + CachedWalkMapper, CopyMapperWithExtraArgs, + CombineMapper) + from pytato.partition import GraphPart, GraphPartition, PartId, GraphPartitioner from pytato.target import BoundProgram +from pytato.analysis import DirectPredecessorsGetter +from pyrsistent import pmap +from functools import cached_property +from pytato.scalar_expr import SCALAR_CLASSES import numpy as np @@ -208,6 +217,20 @@ class DistributedSendRefHolder(Array): def dtype(self) -> np.dtype[Any]: return self.passthrough_data.dtype + def copy(self, **kwargs: Any) -> DistributedSendRefHolder: + # override 'Array.copy' because + # 'DistributedSendRefHolder.axes' is a read-only field. + send = kwargs.pop("send", self.send) + passthrough_data = kwargs.pop("passthrough_data", self.passthrough_data) + tags = kwargs.pop("tags", self.tags) + + if kwargs: + raise ValueError("Cannot assign" + f" DistributedSendRefHolder.'{set(kwargs)}'") + return DistributedSendRefHolder(send, + passthrough_data, + tags) + class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array): """Class representing a distributed receive operation. @@ -411,13 +434,7 @@ def _gather_distributed_comm_info(partition: GraphPartition, # }}} -# {{{ find distributed partition - -@dataclass(frozen=True, eq=True) -class DistributedPartitionId(): - fed_sends: object - feeding_recvs: object - +# {{{ find_distributed_partition class _DistributedGraphPartitioner(GraphPartitioner): @@ -446,64 +463,473 @@ class _DistributedGraphPartitioner(GraphPartitioner): return _gather_distributed_comm_info(partition, self.pid_to_dist_sends) -def find_distributed_partition( - outputs: DictOfNamedArrays) -> DistributedGraphPartition: - """Finds a partitioning in a distributed context.""" +class _MandatoryPartitionOutputsCollector(CombineMapper[FrozenSet[Array]]): + """ + Collects all nodes that, after partitioning, are necessarily outputs + of the partition to which they belong. 
+ """ + def __init__(self) -> None: + super().__init__() + self.partition_outputs: Set[Array] = set() - from pytato.transform import (UsersCollector, TopoSortMapper, - reverse_graph, tag_user_nodes) + def combine(self, *args: FrozenSet[Array]) -> FrozenSet[Array]: + from functools import reduce + return reduce(frozenset.union, args, frozenset()) - gdm = UsersCollector() - gdm(outputs) + def map_distributed_send_ref_holder(self, + expr: DistributedSendRefHolder + ) -> FrozenSet[Array]: + return self.combine(frozenset([expr.send.data]), + super().map_distributed_send_ref_holder(expr)) - graph = gdm.node_to_users + def _map_input_base(self, expr: Array) -> FrozenSet[Array]: + return frozenset() - # type-ignore-reason: - # 'graph' also includes DistributedSend nodes, which are not Arrays - rev_graph = reverse_graph(graph) # type: ignore[arg-type] + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + map_size_param = _map_input_base + map_distributed_recv = _map_input_base - # FIXME: Inefficient... too many traversals - node_to_feeding_recvs: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} - for node in graph: - node_to_feeding_recvs.setdefault(node, set()) - if isinstance(node, DistributedRecv): - tag_user_nodes(graph, tag=node, # type: ignore[arg-type] - starting_point=node, - node_to_tags=node_to_feeding_recvs) - node_to_fed_sends: Dict[ArrayOrNames, Set[ArrayOrNames]] = {} - for node in rev_graph: - node_to_fed_sends.setdefault(node, set()) - if isinstance(node, DistributedSend): - tag_user_nodes(rev_graph, tag=node, starting_point=node, - node_to_tags=node_to_fed_sends) +class _MaterializedArrayCollector(CachedWalkMapper): + """ + Collects all nodes that have to be materialized during code-generation. + """ + def __init__(self) -> None: + super().__init__() + self.materialized_arrays: Set[Array] = set() + + def post_visit(self, expr: Any) -> None: + from pytato.tags import ImplStored + from pytato.loopy import LoopyCallResult + + if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): + self.materialized_arrays.add(expr) + + if isinstance(expr, LoopyCallResult): + self.materialized_arrays.add(expr) + from pytato.loopy import LoopyCall + assert isinstance(expr._container, LoopyCall) + for _, subexpr in sorted(expr._container.bindings.items()): + if isinstance(subexpr, Array): + self.materialized_arrays.add(subexpr) + else: + assert isinstance(subexpr, SCALAR_CLASSES) - def get_part_id(expr: ArrayOrNames) -> DistributedPartitionId: - return DistributedPartitionId(frozenset(node_to_fed_sends[expr]), - frozenset(node_to_feeding_recvs[expr])) + if isinstance(expr, DictOfNamedArrays): + for _, subexpr in sorted(expr._data.items()): + assert isinstance(subexpr, Array) + self.materialized_arrays.add(subexpr) - # {{{ Sanity checks - if __debug__: - for node, _ in node_to_feeding_recvs.items(): - for n in node_to_feeding_recvs[node]: - assert(isinstance(n, DistributedRecv)) +class _DominantMaterializedPredecessorsCollector(Mapper): + """ + A Mapper whose mapper method for a node returns the materialized predecessors + just after the point the node is evaluated. 
+ """ + def __init__(self, materialized_arrays: FrozenSet[Array]) -> None: + super().__init__() + self.materialized_arrays = materialized_arrays + self.cache: Dict[ArrayOrNames, FrozenSet[Array]] = {} + + def _combine(self, values: Iterable[Array]) -> FrozenSet[Array]: + from functools import reduce + return reduce(frozenset.union, + (self.rec(v) for v in values), + frozenset()) + + # type-ignore reason: return type not compatible with Mapper.rec's type + def rec(self, expr: ArrayOrNames) -> FrozenSet[Array]: # type: ignore[override] + try: + return self.cache[expr] + except KeyError: + # type-ignore reason: type not compatible with super.rec() type + result: FrozenSet[Array] = super().rec(expr) # type: ignore[type-var] + self.cache[expr] = result + return result + + @cached_property + def direct_preds_getter(self) -> DirectPredecessorsGetter: + return DirectPredecessorsGetter() + + def _map_generic_node(self, expr: Array) -> FrozenSet[Array]: + direct_preds = self.direct_preds_getter(expr) + + if expr in self.materialized_arrays: + return frozenset([expr]) + else: + return self._combine(direct_preds) + + map_placeholder = _map_generic_node + map_data_wrapper = _map_generic_node + map_size_param = _map_generic_node + + map_index_lambda = _map_generic_node + map_stack = _map_generic_node + map_concatenate = _map_generic_node + map_roll = _map_generic_node + map_axis_permutation = _map_generic_node + map_basic_index = _map_generic_node + map_contiguous_advanced_index = _map_generic_node + map_non_contiguous_advanced_index = _map_generic_node + map_reshape = _map_generic_node + map_einsum = _map_generic_node + map_distributed_recv = _map_generic_node + + def map_named_array(self, expr: NamedArray) -> FrozenSet[Array]: + raise NotImplementedError("only LoopyCallResult named array" + " supported for now.") + + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays + ) -> FrozenSet[Array]: + raise NotImplementedError("Dict of named arrays not (yet) implemented") + + def map_loopy_call_result(self, expr: NamedArray) -> FrozenSet[Array]: + # ``loopy call result` is always materialized. However, make sure to + # traverse its arguments. + assert expr in self.materialized_arrays + return self._map_generic_node(expr) + + def map_distributed_send_ref_holder(self, + expr: DistributedSendRefHolder + ) -> FrozenSet[Array]: + return self.rec(expr.passthrough_data) + + +class _DominantMaterializedPredecessorsRecorder(CachedWalkMapper): + """ + For each node in an expression graph, this mapper records the dominant + predecessors of each node of an expression graph into + :attr:`array_to_mat_preds`. 
+ """ + def __init__(self, mat_preds_getter: Callable[[Array], FrozenSet[Array]] + ) -> None: + super().__init__() + self.mat_preds_getter = mat_preds_getter + self.array_to_mat_preds: Dict[Array, FrozenSet[Array]] = {} - for node, _ in node_to_fed_sends.items(): - for n in node_to_fed_sends[node]: - assert(isinstance(n, DistributedSend)) + @cached_property + def direct_preds_getter(self) -> DirectPredecessorsGetter: + return DirectPredecessorsGetter() - tm = TopoSortMapper() - tm(outputs) + def post_visit(self, expr: Any) -> None: + from functools import reduce + if isinstance(expr, Array): + self.array_to_mat_preds[expr] = reduce( + frozenset.union, + (self.mat_preds_getter(pred) + for pred in self.direct_preds_getter(expr)), + frozenset()) - for node in tm.topological_order: - get_part_id(node) - # }}} +def _linearly_schedule_batches( + predecessors: PMapT[Array, FrozenSet[Array]]) -> PMapT[Array, int]: + """ + Used by :func:`find_distributed_partition`. Based on the dependencies in + *predecessors*, each node is assigned a time such that evaluating the array + at that point in time would not violate dependencies. This "time" or "batch + number" is then used as a partition ID. + """ + from functools import reduce + current_time = 0 + ary_to_time = {} + scheduled_nodes: Set[Array] = set() + all_nodes = frozenset(predecessors) + + # assert that the keys contain all the nodes + assert reduce(frozenset.union, + predecessors.values(), + cast(FrozenSet[Array], frozenset())) <= frozenset(predecessors) + + while len(scheduled_nodes) < len(all_nodes): + # {{{ eagerly schedule nodes whose predecessors have been scheduled + + nodes_to_schedule = {node + for node, preds in predecessors.items() + if ((node not in scheduled_nodes) + and (preds <= scheduled_nodes))} + for node in nodes_to_schedule: + assert node not in ary_to_time + ary_to_time[node] = current_time + + scheduled_nodes.update(nodes_to_schedule) + + current_time += 1 + + # }}} + + return pmap(ary_to_time) + + +def _assign_materialized_arrays_to_part_id( + materialized_arrays: FrozenSet[Array], + array_to_output_deps: Mapping[Array, FrozenSet[Array]], + outputs_to_part_id: Mapping[Array, int] +) -> PMapT[Array, int]: + """ + Returns a mapping from a materialized array to the part's ID where all the + inputs of the array expression are available. + + Invoked as an intermediate step in :func:`find_distributed_partition`. + + .. note:: + + In this heuristic we compute the materialized array as soon as its + inputs are available. In some cases it might be worth exploring + schedules where the evaluation of an array is delayed until one of + its users demand it. + """ + + materialized_array_to_part_id: Dict[Array, int] = {} + + for ary in materialized_arrays: + materialized_array_to_part_id[ary] = max( + (outputs_to_part_id[dep] + for dep in array_to_output_deps[ary]), + default=-1) + 1 + + return pmap(materialized_array_to_part_id) + + +def _get_array_to_dominant_materialized_deps( + outputs: DictOfNamedArrays, + materialized_arrays: FrozenSet[Array]) -> PMapT[Array, FrozenSet[Array]]: + """ + Returns a mapping from each node in the DAG *outputs* to a :class:`frozenset` + of its dominant materialized predecessors. 
+ """ + + dominant_materialized_deps = _DominantMaterializedPredecessorsCollector( + materialized_arrays) + dominant_materialized_deps_recorder = ( + _DominantMaterializedPredecessorsRecorder(dominant_materialized_deps)) + dominant_materialized_deps_recorder(outputs) + return pmap(dominant_materialized_deps_recorder.array_to_mat_preds) + +def _get_materialized_arrays_promoted_to_partition_outputs( + ary_to_dominant_stored_preds: Mapping[Array, FrozenSet[Array]], + stored_ary_to_part_id: Mapping[Array, int], + materialized_arrays: FrozenSet[Array] +) -> FrozenSet[Array]: + """ + Returns a :class:`frozenset` of materialized arrays that are used by + multiple partitions. Materialized arrays that are used by multiple + partitions are special in that they *must* be promoted as the outputs of a + partition. + + Invoked as an intermediate step in :func:`find_distributed_partition`. + + :arg ary_to_dominant_stored_preds: A mapping from array to the dominant + stored predecessors. A stored array can be either a mandatory partition + output or a materialized array as indicated by the user. + """ + materialized_ary_to_part_id_users: Dict[Array, Set[int]] = {} + + for ary in stored_ary_to_part_id: + stored_preds = ary_to_dominant_stored_preds[ary] + for pred in stored_preds: + if pred in materialized_arrays: + (materialized_ary_to_part_id_users + .setdefault(pred, set()) + .add(stored_ary_to_part_id[ary])) + + return frozenset({ary + for ary, users in materialized_ary_to_part_id_users.items() + if users != {stored_ary_to_part_id[ary]}}) + + +@dataclass(frozen=True, eq=True, repr=True) +class PartIDTag(UniqueTag): + """ + A tag applicable to a :class:`pytato.Array` recording to which part the + array belongs. + """ + part_id: int + + +class _PartIDTagAssigner(CopyMapperWithExtraArgs): + """ + Used by :func:`find_distributed_partition` to assign each array + node a :class:`PartIDTag`. + """ + def __init__(self, + stored_array_to_part_id: Mapping[Array, int], + partition_outputs: FrozenSet[Array]) -> None: + self.stored_array_to_part_id = stored_array_to_part_id + self.partition_outputs = partition_outputs + + # type-ignore reason: incompatible attribute type wrt base. + self._cache: Dict[Tuple[ArrayOrNames, int], + Any] = {} # type: ignore[assignment] + + # type-ignore-reason: incompatible with super class + def cache_key(self, # type: ignore[override] + expr: ArrayOrNames, + user_part_id: int + ) -> Tuple[ArrayOrNames, int]: + + return (expr, user_part_id) + + # type-ignore-reason: incompatible with super class + def rec(self, # type: ignore[override] + expr: ArrayOrNames, + user_part_id: int) -> Any: + key = self.cache_key(expr, user_part_id) + try: + return self._cache[key] + except KeyError: + if isinstance(expr, Array): + if expr in self.stored_array_to_part_id: + assert ((self.stored_array_to_part_id[expr] + == user_part_id) + or expr in self.partition_outputs) + # at stored array the part id changes + user_part_id = self.stored_array_to_part_id[expr] + + expr = expr.tagged(PartIDTag(user_part_id)) + + result = super().rec(expr, user_part_id) + self._cache[key] = result + return result + + +def find_distributed_partition(outputs: DictOfNamedArrays + ) -> DistributedGraphPartition: + """ + Partitions *outputs* into parts. Between two parts communication + statements (sends/receives) are scheduled. + + .. note:: + + The partitioning of a DAG generally does not have a unique solution. + The heuristic employed by this partitioner is as follows: + + 1. 
The data contained in :class:`~pytato.DistributedSend` are marked as
+           *mandatory part outputs*.
+        2. Based on the dependencies in *outputs*, a DAG is constructed with
+           only the mandatory part outputs as the nodes.
+        3. Using a topological sort the mandatory part outputs are assigned a
+           "time" (an integer) such that evaluating these outputs at that time
+           would preserve dependencies. We maximize the number of part outputs
+           scheduled at each "time". This requirement ensures our topological
+           sort is deterministic.
+        4. We then turn our attention to the other arrays that are allocated to a
+           buffer. These are the materialized arrays and belong to one of the
+           following classes:
+           - An :class:`~pytato.Array` tagged with :class:`pytato.tags.ImplStored`.
+           - The expressions in a :class:`~pytato.DictOfNamedArrays`.
+        5. Based on *outputs*, we compute the predecessors of a materialized
+           array that were a part of the mandatory part outputs. A materialized
+           array is scheduled to be evaluated in a part as soon as all of its
+           inputs are available. Note that certain inputs (like
+           :class:`~pytato.DistributedRecv`) might not be available until
+           certain mandatory part outputs have been evaluated.
+        6. From *outputs*, we can construct a DAG comprising only mandatory
+           part outputs and materialized arrays. We mark all materialized
+           arrays that are being used by nodes in a part that's not the one in
+           which the materialized array itself was evaluated. Such materialized
+           arrays are also realized as part outputs. This is done to avoid
+           recomputations.
+
+        Knobs to tweak the partition:
+
+        1. By removing dependencies between the mandatory part outputs, the
+           resulting DAG would lead to fewer parts, each containing a larger
+           number of nodes. Similarly, adding dependencies between the part
+           outputs would lead to smaller parts.
+        2. Tagging nodes with :class:`~pytato.tags.ImplStored` would help in
+           avoiding re-computations.
+ """ + from pytato.transform import SubsetDependencyMapper + from pytato.array import make_dict_of_named_arrays from pytato.partition import find_partition + + # {{{ get partitioning helper data corresponding to the DAG + + partition_outputs = _MandatoryPartitionOutputsCollector()(outputs) + + # materialized_arrays: "extra" arrays that must be stored in a buffer + materialized_arrays_collector = _MaterializedArrayCollector() + materialized_arrays_collector(outputs) + materialized_arrays = frozenset( + materialized_arrays_collector.materialized_arrays) - partition_outputs + + # }}} + + dep_mapper = SubsetDependencyMapper(partition_outputs) + + # {{{ compute a dependency graph between outputs, schedule and partition + + output_to_deps = pmap({partition_out: (dep_mapper(partition_out) + - frozenset([partition_out])) + for partition_out in partition_outputs}) + + output_to_part_id = _linearly_schedule_batches(output_to_deps) + + # }}} + + # {{{ assign each materialized array a partition ID in which it will be placed + + materialized_array_to_output_deps = pmap({ary: (dep_mapper(ary) + - frozenset([ary])) + for ary in materialized_arrays}) + materialized_ary_to_part_id = _assign_materialized_arrays_to_part_id( + materialized_arrays, + materialized_array_to_output_deps, + output_to_part_id) + + assert frozenset(materialized_ary_to_part_id) == materialized_arrays + + # }}} + + stored_ary_to_part_id = materialized_ary_to_part_id.update(output_to_part_id) + + # {{{ find which materialized arrays have users in multiple parts + # (and promote them to part outputs) + + ary_to_dominant_materialized_deps = ( + _get_array_to_dominant_materialized_deps(outputs, + (materialized_arrays + | partition_outputs))) + + materialized_arrays_realized_as_partition_outputs = ( + _get_materialized_arrays_promoted_to_partition_outputs( + ary_to_dominant_materialized_deps, + stored_ary_to_part_id, + materialized_arrays)) + + # }}} + + # {{{ tag each node with its part ID + + # Why is this necessary? (I.e. isn't the mapping *stored_ary_to_part_id* enough?) + # By assigning tags we also duplicate the non-materialized nodes that are + # to be made available in multiple parts. Parts being disjoint is one of + # the requirements of *find_partition*. + part_id_tag_assigner = _PartIDTagAssigner( + stored_ary_to_part_id, + partition_outputs | materialized_arrays_realized_as_partition_outputs) + + partitioned_outputs = make_dict_of_named_arrays({ + name: part_id_tag_assigner(subexpr, stored_ary_to_part_id[subexpr]) + for name, subexpr in outputs._data.items()}) + + # }}} + + def get_part_id(expr: ArrayOrNames) -> int: + if not isinstance(expr, Array): + raise NotImplementedError("find_distributed_partition" + " cannot partition DictOfNamedArrays") + assert isinstance(expr, Array) + tag, = expr.tags_of_type(PartIDTag) + return tag.part_id + return cast(DistributedGraphPartition, - find_partition(outputs, get_part_id, _DistributedGraphPartitioner)) + find_partition(partitioned_outputs, + get_part_id, + _DistributedGraphPartitioner) + ) # }}}
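Reviewer note, not part of the patch: a minimal usage sketch of where the reworked find_distributed_partition sits in the existing distributed workflow. The helper signatures (make_distributed_recv, staple_distributed_send, number_distributed_tags, generate_code_for_partition, execute_distributed_partition) are used as I read them from this diff and the surrounding code and may need adjusting; the placeholder name, the "halo" tag, and the ring-exchange DAG are made up for illustration.

from mpi4py import MPI
import numpy as np
import pytato as pt

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

x = pt.make_placeholder("x", shape=(10,), dtype=np.float64)

# send x to the next rank, receive the previous rank's contribution
recvd = pt.make_distributed_recv(
        src_rank=(rank - 1) % size, comm_tag="halo",
        shape=(10,), dtype=np.float64)
y = pt.staple_distributed_send(
        x, dest_rank=(rank + 1) % size, comm_tag="halo",
        stapled_to=x + recvd)

outputs = pt.make_dict_of_named_arrays({"out": 2 * y})

# entry point touched by this diff: sends/receives end up on part boundaries,
# everything else is grouped into as few parts as dependencies allow
partition = pt.find_distributed_partition(outputs)

# downstream steps, unchanged by this diff
partition, _ = pt.number_distributed_tags(comm, partition, base_tag=4242)
prg_per_part = pt.generate_code_for_partition(partition)
# pt.execute_distributed_partition(partition, prg_per_part, queue, comm,
#                                  input_args={"x": ...})  # needs a CL queue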
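The "batch number" scheduling performed by _linearly_schedule_batches (steps 2-3 of the heuristic in the find_distributed_partition docstring) can be pictured without any pytato machinery. The self-contained toy sketch below mirrors its loop on a hand-written predecessor mapping; the node names are hypothetical.

# Illustration of batch scheduling: repeatedly schedule every node whose
# predecessors have already been scheduled, bumping the "time" after each sweep.
from typing import Dict, FrozenSet, Hashable, Set


def schedule_batches(
        predecessors: Dict[Hashable, FrozenSet[Hashable]]) -> Dict[Hashable, int]:
    node_to_time: Dict[Hashable, int] = {}
    scheduled: Set[Hashable] = set()
    time = 0
    while len(scheduled) < len(predecessors):
        ready = {node for node, preds in predecessors.items()
                 if node not in scheduled and preds <= scheduled}
        assert ready, "dependency cycle"
        for node in ready:
            node_to_time[node] = time
        scheduled |= ready
        time += 1
    return node_to_time


# "recv0" must be evaluated before "send1"; independent outputs share a batch.
print(schedule_batches({
    "recv0": frozenset(),
    "send1": frozenset({"recv0"}),
    "send2": frozenset(),           # independent of recv0, lands in batch 0
    "out": frozenset({"send1", "send2"}),
}))
# -> e.g. {"recv0": 0, "send2": 0, "send1": 1, "out": 2}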