From 924bf21ade8cda6de4bba619b397a82d3bb544da Mon Sep 17 00:00:00 2001 From: Ellis <eshoag@illinois.edu> Date: Thu, 25 Jan 2018 11:51:30 -0600 Subject: [PATCH] Add simple mpi test --- test/test_mpi_communication.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index f3a8118..6244dcc 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -36,7 +36,7 @@ from grudge import sym, bind, Discretization from grudge.shortcuts import set_up_rk4 -def simple_mpi_communication_entrypoint(): +def simple_mpi_communication_entrypoint(order): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx) from meshmode.distributed import MPIMeshDistributor @@ -47,19 +47,19 @@ def simple_mpi_communication_entrypoint(): mesh_dist = MPIMeshDistributor(comm) - order = 2 - if mesh_dist.is_mananger_rank(): from meshmode.mesh.generation import generate_regular_rect_mesh - mesh = generate_regular_rect_mesh(a=(-0.5,)*2, - b=(0.5,)*2, + mesh = generate_regular_rect_mesh(a=(-1,)*2, + b=(1,)*2, n=(3,)*2) - from pymetis import part_graph - _, p = part_graph(num_parts, - xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), - adjncy=mesh.nodal_adjacency.neighbors.tolist()) - part_per_element = np.array(p) + # This gives [0, 0, 0, 1, 0, 1, 1, 1] + # from pymetis import part_graph + # _, p = part_graph(num_parts, + # xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), + # adjncy=mesh.nodal_adjacency.neighbors.tolist()) + # part_per_element = np.array(p) + part_per_element = np.array([0, 0, 0, 1, 0, 1, 1, 1]) local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) else: @@ -225,7 +225,8 @@ def test_mpi(num_ranks): @pytest.mark.mpi @pytest.mark.parametrize("num_ranks", [2]) -def test_simple_mpi(num_ranks): +@pytest.mark.parametrize("order", [2]) +def test_simple_mpi(num_ranks, order): pytest.importorskip("mpi4py") from subprocess import check_call @@ -233,6 +234,7 @@ def test_simple_mpi(num_ranks): newenv = os.environ.copy() newenv["RUN_WITHIN_MPI"] = "1" newenv["TEST_SIMPLE_MPI_COMMUNICATION"] = "1" + newenv["order"] = str(order) check_call([ "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI", sys.executable, __file__], @@ -246,7 +248,8 @@ if __name__ == "__main__": if "TEST_MPI_COMMUNICATION" in os.environ: mpi_communication_entrypoint() elif "TEST_SIMPLE_MPI_COMMUNICATION" in os.environ: - simple_mpi_communication_entrypoint() + order = int(os.environ["order"]) + simple_mpi_communication_entrypoint(order) else: import sys if len(sys.argv) > 1: -- GitLab