From 924bf21ade8cda6de4bba619b397a82d3bb544da Mon Sep 17 00:00:00 2001
From: Ellis <eshoag@illinois.edu>
Date: Thu, 25 Jan 2018 11:51:30 -0600
Subject: [PATCH] Parametrize simple MPI test by order

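Parametrize simple_mpi_communication_entrypoint by the polynomial
order and pass the order to the MPI subprocess through the
environment. Hardcode the element partition previously computed with
pymetis so the test no longer imports pymetis, and enlarge the test
domain from [-0.5, 0.5]^2 to [-1, 1]^2.

The test can be run on its own with, e.g.:

    python -m pytest -m mpi test/test_mpi_communication.py -k test_simple_mpi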
---
 test/test_mpi_communication.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py
index f3a8118..6244dcc 100644
--- a/test/test_mpi_communication.py
+++ b/test/test_mpi_communication.py
@@ -36,7 +36,7 @@ from grudge import sym, bind, Discretization
 from grudge.shortcuts import set_up_rk4
 
 
-def simple_mpi_communication_entrypoint():
+def simple_mpi_communication_entrypoint(order):
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
     from meshmode.distributed import MPIMeshDistributor
@@ -47,19 +47,19 @@ def simple_mpi_communication_entrypoint():
 
     mesh_dist = MPIMeshDistributor(comm)
 
-    order = 2
-
     if mesh_dist.is_mananger_rank():
         from meshmode.mesh.generation import generate_regular_rect_mesh
-        mesh = generate_regular_rect_mesh(a=(-0.5,)*2,
-                                          b=(0.5,)*2,
+        mesh = generate_regular_rect_mesh(a=(-1,)*2,
+                                          b=(1,)*2,
                                           n=(3,)*2)
 
-        from pymetis import part_graph
-        _, p = part_graph(num_parts,
-                          xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
-                          adjncy=mesh.nodal_adjacency.neighbors.tolist())
-        part_per_element = np.array(p)
+        # Hardcode the partition that the pymetis call below would compute:
+        # from pymetis import part_graph
+        # _, p = part_graph(num_parts,
+        #                   xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
+        #                   adjncy=mesh.nodal_adjacency.neighbors.tolist())
+        # part_per_element = np.array(p)
+        part_per_element = np.array([0, 0, 0, 1, 0, 1, 1, 1])
 
         local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
     else:
@@ -225,7 +225,8 @@ def test_mpi(num_ranks):
 
 @pytest.mark.mpi
 @pytest.mark.parametrize("num_ranks", [2])
-def test_simple_mpi(num_ranks):
+@pytest.mark.parametrize("order", [2])
+def test_simple_mpi(num_ranks, order):
     pytest.importorskip("mpi4py")
 
     from subprocess import check_call
@@ -233,6 +234,7 @@ def test_simple_mpi(num_ranks):
     newenv = os.environ.copy()
     newenv["RUN_WITHIN_MPI"] = "1"
     newenv["TEST_SIMPLE_MPI_COMMUNICATION"] = "1"
+    newenv["order"] = str(order)
     check_call([
         "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI",
         sys.executable, __file__],
@@ -246,7 +248,8 @@ if __name__ == "__main__":
         if "TEST_MPI_COMMUNICATION" in os.environ:
             mpi_communication_entrypoint()
         elif "TEST_SIMPLE_MPI_COMMUNICATION" in os.environ:
-            simple_mpi_communication_entrypoint()
+            order = int(os.environ["ORDER"])
+            simple_mpi_communication_entrypoint(order)
     else:
         import sys
         if len(sys.argv) > 1:
-- 
GitLab