diff --git a/pytato/distributed.py b/pytato/distributed.py index 6c52ccfa7cbb1992777a9076919221398c626b6c..b420207b9bd7c48851916d2c6333e6b0178ef709 100644 --- a/pytato/distributed.py +++ b/pytato/distributed.py @@ -152,6 +152,13 @@ class DistributedSend(Taggable): and self.comm_tag == other.comm_tag and self.tags == other.tags) + def _with_new_tags(self, tags: FrozenSet[Tag]) -> DistributedSend: + return DistributedSend( + data=self.data, + dest_rank=self.dest_rank, + comm_tag=self.comm_tag, + tags=tags) + def copy(self, **kwargs: Any) -> DistributedSend: data: Optional[Array] = kwargs.get("data") dest_rank: Optional[int] = kwargs.get("dest_rank")