From 9dcba0ba6a4242db7acee3cfa2993f9a19357d7b Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 20 Mar 2022 13:45:36 -0500 Subject: [PATCH] MPMS materialization: support distributed nodes Co-authored-by: Andreas Kloeckner --- pytato/analysis/__init__.py | 21 +++++++++++++++++++++ pytato/transform.py | 22 ++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 83ec104..3b21b6e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -47,6 +47,13 @@ class NUserCollector(Mapper): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of times an array expression is a direct dependency of other nodes. + + .. note:: + + - We do not consider the :class:`pytato.DistributedSendRefHolder` + a user of :attr:`pytato.DistributedSendRefHolder.send`. This is + because in a data flow sense, the send-ref holder does not use the + send's data. """ def __init__(self) -> None: from collections import defaultdict @@ -141,6 +148,20 @@ class NUserCollector(Mapper): map_data_wrapper = _map_input_base map_size_param = _map_input_base + def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder + ) -> None: + # Note: We do not consider 'expr.send.data' as a predecessor of *expr*, + # as there is no dataflow from *expr.send.data* to *expr* + self.nusers[expr.passthrough_data] += 1 + self.rec(expr.passthrough_data) + self.rec(expr.send.data) + + def map_distributed_recv(self, expr: DistributedRecv) -> None: + for dim in expr.shape: + if isinstance(dim, Array): + self.nusers[dim] += 1 + self.rec(dim) + def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]: """ diff --git a/pytato/transform.py b/pytato/transform.py index 49055f5..d38edc2 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -985,6 +985,28 @@ class MPMSMaterializer(Mapper): # loopy call result is always materialized return MPMSMaterializerAccumulator(frozenset([expr]), expr) + def map_distributed_send_ref_holder(self, + expr: DistributedSendRefHolder + ) -> MPMSMaterializerAccumulator: + from pytato.distributed import (DistributedSendRefHolder, + DistributedSend) + rec_passthrough = self.rec(expr.passthrough_data) + rec_send_data = self.rec(expr.send.data) + new_expr = DistributedSendRefHolder( + send=DistributedSend(rec_send_data.expr, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags), + passthrough_data=rec_passthrough.expr, + tags=expr.tags, + ) + return MPMSMaterializerAccumulator( + rec_passthrough.materialized_predecessors, new_expr) + + def map_distributed_recv(self, expr: DistributedRecv + ) -> MPMSMaterializerAccumulator: + return MPMSMaterializerAccumulator(frozenset([expr]), expr) + # }}} -- GitLab