diff --git a/grudge/discretization.py b/grudge/discretization.py index e700e1812f705ae6fd5aedfd93df41b5c3ebf3b6..54bad00c7159f1d6cd1b5f330abd0b0dfd413c88 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -115,30 +115,33 @@ class DGDiscretizationWithBoundaries: def _set_up_distributed_communication(self, mpi_communicator, array_context): from_dd = sym.DOFDesc("vol", sym.QTAG_NONE) + boundary_connections = {} + from meshmode.distributed import get_connected_partitions connected_parts = get_connected_partitions(self._volume_discr.mesh) - if mpi_communicator is None and connected_parts: - raise RuntimeError("must supply an MPI communicator when using a " + if connected_parts: + if mpi_communicator is None: + raise RuntimeError("must supply an MPI communicator when using a " "distributed mesh") - grp_factory = self.group_factory_for_quadrature_tag(sym.QTAG_NONE) - - setup_helpers = {} - boundary_connections = {} - - from meshmode.distributed import MPIBoundaryCommSetupHelper - for i_remote_part in connected_parts: - conn = self.connection_from_dds( - from_dd, - sym.DOFDesc(sym.BTAG_PARTITION(i_remote_part), sym.QTAG_NONE)) - setup_helper = setup_helpers[i_remote_part] = MPIBoundaryCommSetupHelper( - mpi_communicator, array_context, conn, - i_remote_part, grp_factory) - setup_helper.post_sends() - - for i_remote_part, setup_helper in setup_helpers.items(): - boundary_connections[i_remote_part] = setup_helper.complete_setup() + grp_factory = self.group_factory_for_quadrature_tag(sym.QTAG_NONE) + + local_boundary_connections = {} + for i_remote_part in connected_parts: + local_boundary_connections[i_remote_part] = self.connection_from_dds( + from_dd, sym.DOFDesc(sym.BTAG_PARTITION(i_remote_part), + sym.QTAG_NONE)) + + from meshmode.distributed import MPIBoundaryCommSetupHelper + with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, + local_boundary_connections, grp_factory) as bdry_setup_helper: + while True: + conns = bdry_setup_helper.complete_some() + if not conns: + break + for i_remote_part, conn in conns.items(): + boundary_connections[i_remote_part] = conn return boundary_connections