diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 1549d9755fbd03dad162a1c9029751b1852c4891..3da47b91e0744eb566c6317352fa684be37de6f0 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -104,7 +104,7 @@ def _map_distributed_graph_partition_nodes(
             parts={
                 pid: replace(part,
                     input_name_to_recv_node={
-                        in_name: map_array(recv)
+                        in_name: cast(DistributedRecv, map_array(recv))
                         for in_name, recv in part.input_name_to_recv_node.items()},
                     output_name_to_send_node={
                         out_name: map_send(send)