Skip to content
Snippets Groups Projects
Commit 924bf21a authored by Ellis's avatar Ellis
Browse files

Add simple mpi test

parent 322eacbd
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,7 @@ from grudge import sym, bind, Discretization ...@@ -36,7 +36,7 @@ from grudge import sym, bind, Discretization
from grudge.shortcuts import set_up_rk4 from grudge.shortcuts import set_up_rk4
def simple_mpi_communication_entrypoint(): def simple_mpi_communication_entrypoint(order):
cl_ctx = cl.create_some_context() cl_ctx = cl.create_some_context()
queue = cl.CommandQueue(cl_ctx) queue = cl.CommandQueue(cl_ctx)
from meshmode.distributed import MPIMeshDistributor from meshmode.distributed import MPIMeshDistributor
...@@ -47,19 +47,19 @@ def simple_mpi_communication_entrypoint(): ...@@ -47,19 +47,19 @@ def simple_mpi_communication_entrypoint():
mesh_dist = MPIMeshDistributor(comm) mesh_dist = MPIMeshDistributor(comm)
order = 2
if mesh_dist.is_mananger_rank(): if mesh_dist.is_mananger_rank():
from meshmode.mesh.generation import generate_regular_rect_mesh from meshmode.mesh.generation import generate_regular_rect_mesh
mesh = generate_regular_rect_mesh(a=(-0.5,)*2, mesh = generate_regular_rect_mesh(a=(-1,)*2,
b=(0.5,)*2, b=(1,)*2,
n=(3,)*2) n=(3,)*2)
from pymetis import part_graph # This gives [0, 0, 0, 1, 0, 1, 1, 1]
_, p = part_graph(num_parts, # from pymetis import part_graph
xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), # _, p = part_graph(num_parts,
adjncy=mesh.nodal_adjacency.neighbors.tolist()) # xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
part_per_element = np.array(p) # adjncy=mesh.nodal_adjacency.neighbors.tolist())
# part_per_element = np.array(p)
part_per_element = np.array([0, 0, 0, 1, 0, 1, 1, 1])
local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
else: else:
...@@ -225,7 +225,8 @@ def test_mpi(num_ranks): ...@@ -225,7 +225,8 @@ def test_mpi(num_ranks):
@pytest.mark.mpi @pytest.mark.mpi
@pytest.mark.parametrize("num_ranks", [2]) @pytest.mark.parametrize("num_ranks", [2])
def test_simple_mpi(num_ranks): @pytest.mark.parametrize("order", [2])
def test_simple_mpi(num_ranks, order):
pytest.importorskip("mpi4py") pytest.importorskip("mpi4py")
from subprocess import check_call from subprocess import check_call
...@@ -233,6 +234,7 @@ def test_simple_mpi(num_ranks): ...@@ -233,6 +234,7 @@ def test_simple_mpi(num_ranks):
newenv = os.environ.copy() newenv = os.environ.copy()
newenv["RUN_WITHIN_MPI"] = "1" newenv["RUN_WITHIN_MPI"] = "1"
newenv["TEST_SIMPLE_MPI_COMMUNICATION"] = "1" newenv["TEST_SIMPLE_MPI_COMMUNICATION"] = "1"
newenv["order"] = str(order)
check_call([ check_call([
"mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI", "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI",
sys.executable, __file__], sys.executable, __file__],
...@@ -246,7 +248,8 @@ if __name__ == "__main__": ...@@ -246,7 +248,8 @@ if __name__ == "__main__":
if "TEST_MPI_COMMUNICATION" in os.environ: if "TEST_MPI_COMMUNICATION" in os.environ:
mpi_communication_entrypoint() mpi_communication_entrypoint()
elif "TEST_SIMPLE_MPI_COMMUNICATION" in os.environ: elif "TEST_SIMPLE_MPI_COMMUNICATION" in os.environ:
simple_mpi_communication_entrypoint() order = int(os.environ["order"])
simple_mpi_communication_entrypoint(order)
else: else:
import sys import sys
if len(sys.argv) > 1: if len(sys.argv) > 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment