diff --git a/grudge/discretization.py b/grudge/discretization.py index f897b73fe32405de8a4aae99ac760771d92051e1..b218e56fa05b4429f3b597f5fbefe4ff01227ff1 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -115,33 +115,31 @@ class DGDiscretizationWithBoundaries: def _set_up_distributed_communication(self, mpi_communicator, array_context): from_dd = sym.DOFDesc("vol", sym.QTAG_NONE) + boundary_connections = {} + from meshmode.distributed import get_connected_partitions connected_parts = get_connected_partitions(self._volume_discr.mesh) - if mpi_communicator is None and connected_parts: - raise RuntimeError("must supply an MPI communicator when using a " + if connected_parts: + if mpi_communicator is None: + raise RuntimeError("must supply an MPI communicator when using a " "distributed mesh") - grp_factory = self.group_factory_for_quadrature_tag(sym.QTAG_NONE) - - setup_helpers = {} - boundary_connections = {} - - from meshmode.distributed import MPIBoundaryCommSetupHelper - for i_remote_part in connected_parts: - conn = self.connection_from_dds( - from_dd, - sym.DOFDesc(sym.BTAG_PARTITION(i_remote_part), sym.QTAG_NONE)) - setup_helper = setup_helpers[i_remote_part] = MPIBoundaryCommSetupHelper( - mpi_communicator, array_context, conn, - i_remote_part, grp_factory) - setup_helper.post_send() - - for i_remote_part, setup_helper in setup_helpers.items(): - boundary_connections[i_remote_part] = setup_helper.recv() - - for setup_helper in setup_helpers.values(): - setup_helper.complete_send() + grp_factory = self.group_factory_for_quadrature_tag(sym.QTAG_NONE) + + local_boundary_connections = {} + for i_remote_part in connected_parts: + local_boundary_connections[i_remote_part] = self.connection_from_dds( + from_dd, sym.DOFDesc(sym.BTAG_PARTITION(i_remote_part), + sym.QTAG_NONE)) + + from meshmode.distributed import MPIBoundaryCommSetupHelper + with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, + connected_parts, local_boundary_connections, + grp_factory) as bdry_setup_helper: + while conns := bdry_setup_helper.complete_some(): + for i_remote_part, conn in conns.items(): + boundary_connections[i_remote_part] = conn return boundary_connections