diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 83ec1046ed50c3552d6b562520e08bef484236d5..3b21b6e1512c3064b9f9d429b5043fd9e260c897 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -47,6 +47,13 @@ class NUserCollector(Mapper): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of times an array expression is a direct dependency of other nodes. + + .. note:: + + - We do not consider the :class:`pytato.DistributedSendRefHolder` + a user of :attr:`pytato.DistributedSendRefHolder.send`. This is + because in a data flow sense, the send-ref holder does not use the + send's data. """ def __init__(self) -> None: from collections import defaultdict @@ -141,6 +148,20 @@ class NUserCollector(Mapper): map_data_wrapper = _map_input_base map_size_param = _map_input_base + def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder + ) -> None: + # Note: We do not consider 'expr.send.data' as a predecessor of *expr*, + # as there is no dataflow from *expr.send.data* to *expr* + self.nusers[expr.passthrough_data] += 1 + self.rec(expr.passthrough_data) + self.rec(expr.send.data) + + def map_distributed_recv(self, expr: DistributedRecv) -> None: + for dim in expr.shape: + if isinstance(dim, Array): + self.nusers[dim] += 1 + self.rec(dim) + def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]: """ diff --git a/pytato/transform.py b/pytato/transform.py index 49055f5f50c3c6ce83ebf088008907f7df9e532d..d38edc2c72b9018ae34c113645ab92f0316e2ab4 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -985,6 +985,28 @@ class MPMSMaterializer(Mapper): # loopy call result is always materialized return MPMSMaterializerAccumulator(frozenset([expr]), expr) + def map_distributed_send_ref_holder(self, + expr: DistributedSendRefHolder + ) -> MPMSMaterializerAccumulator: + from pytato.distributed import (DistributedSendRefHolder, + DistributedSend) + rec_passthrough = self.rec(expr.passthrough_data) + rec_send_data = self.rec(expr.send.data) + new_expr = DistributedSendRefHolder( + send=DistributedSend(rec_send_data.expr, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags), + passthrough_data=rec_passthrough.expr, + tags=expr.tags, + ) + return MPMSMaterializerAccumulator( + rec_passthrough.materialized_predecessors, new_expr) + + def map_distributed_recv(self, expr: DistributedRecv + ) -> MPMSMaterializerAccumulator: + return MPMSMaterializerAccumulator(frozenset([expr]), expr) + # }}}