diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 1549d9755fbd03dad162a1c9029751b1852c4891..3da47b91e0744eb566c6317352fa684be37de6f0 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -104,7 +104,7 @@ def _map_distributed_graph_partition_nodes( parts={ pid: replace(part, input_name_to_recv_node={ - in_name: map_array(recv) + in_name: cast(DistributedRecv, map_array(recv)) for in_name, recv in part.input_name_to_recv_node.items()}, output_name_to_send_node={ out_name: map_send(send)