diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 0ab13f05351f6ff23f6cd1a7bd96587526a30c96..7a1f3a4188003a5f4f08ff8fbeb6f6b8470ac42a 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -39,7 +39,7 @@ from grudge.shortcuts import set_up_rk4 def simple_mpi_communication_entrypoint(): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx) - from meshmode.distributed import MPIMeshDistributor + from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis from mpi4py import MPI comm = MPI.COMM_WORLD @@ -53,11 +53,7 @@ def simple_mpi_communication_entrypoint(): b=(1,)*2, n=(3,)*2) - from pymetis import part_graph - _, p = part_graph(num_parts, - xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), - adjncy=mesh.nodal_adjacency.neighbors.tolist()) - part_per_element = np.array(p) + part_per_element = get_partition_by_pymetis(mesh, num_parts) local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) else: @@ -108,7 +104,7 @@ def mpi_communication_entrypoint(): i_local_rank = comm.Get_rank() num_parts = comm.Get_size() - from meshmode.distributed import MPIMeshDistributor + from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis mesh_dist = MPIMeshDistributor(comm) dim = 2 @@ -121,11 +117,7 @@ def mpi_communication_entrypoint(): b=(0.5,)*dim, n=(16,)*dim) - from pymetis import part_graph - _, p = part_graph(num_parts, - xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), - adjncy=mesh.nodal_adjacency.neighbors.tolist()) - part_per_element = np.array(p) + part_per_element = get_partition_by_pymetis(mesh, num_parts) local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) else: @@ -261,6 +253,7 @@ def mpi_communication_entrypoint(): @pytest.mark.parametrize("num_ranks", [3]) def test_mpi(num_ranks): pytest.importorskip("mpi4py") + pytest.importorskip("pymetis") from subprocess import check_call import sys @@ -277,6 +270,7 @@ def test_mpi(num_ranks): @pytest.mark.parametrize("num_ranks", [3]) def test_simple_mpi(): pytest.importorskip("mpi4py") + pytest.importorskip("pymetis") from subprocess import check_call import sys