From 3791c78c52ef3c41d3a8a0daa59b5b4241899388 Mon Sep 17 00:00:00 2001
From: Ellis <eshoag@illinois.edu>
Date: Thu, 18 Jan 2018 12:20:24 -0600
Subject: [PATCH] Add simple test case

---
 examples/wave/wave-min.py      |   2 -
 test/test_mpi_communication.py | 221 +++++++++++----------------------
 2 files changed, 70 insertions(+), 153 deletions(-)

diff --git a/examples/wave/wave-min.py b/examples/wave/wave-min.py
index bd3424b..aa119aa 100644
--- a/examples/wave/wave-min.py
+++ b/examples/wave/wave-min.py
@@ -84,8 +84,6 @@ def main(write_output=True, order=4):
 
     # print(sym.pretty(op.sym_operator()))
     bound_op = bind(discr, op.sym_operator())
-    print(bound_op)
-    1/0
 
     def rhs(t, w):
         return bound_op(queue, t=t, w=w)
diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py
index 30710d6..05def1d 100644
--- a/test/test_mpi_communication.py
+++ b/test/test_mpi_communication.py
@@ -36,153 +36,66 @@ from grudge import sym, bind, Discretization
 from grudge.shortcuts import set_up_rk4
 
 
-# TODO: Make new test
-# Create a partitioned mesh and apply sin(2x + 3y) to its field
-# If everything is working, the boundaries of the partitions should be continuous
-# Look at int_tpair
-# Interpolate volume to boundary, ask for the opposite partition at the boundary
-# then compare
-# def mpi_communication_entrypoint():
-#     cl_ctx = cl.create_some_context()
-#     queue = cl.CommandQueue(cl_ctx)
-#     from meshmode.distributed import MPIMeshDistributor
-#
-#     from mpi4py import MPI
-#     comm = MPI.COMM_WORLD
-#     rank = comm.Get_rank()
-#     num_parts = comm.Get_size()
-#
-#     mesh_dist = MPIMeshDistributor(comm)
-#
-#     dims = 2
-#     dt = 0.04
-#     order = 6
-#
-#     if mesh_dist.is_mananger_rank():
-#         from meshmode.mesh.generation import generate_regular_rect_mesh
-#         mesh = generate_regular_rect_mesh(a=(-0.5,)*dims,
-#                                           b=(0.5,)*dims,
-#                                           n=(16,)*dims)
-#
-#         from pymetis import part_graph
-#         _, p = part_graph(num_parts,
-#                           xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
-#                           adjncy=mesh.nodal_adjacency.neighbors.tolist())
-#         part_per_element = np.array(p)
-#
-#         local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
-#     else:
-#         local_mesh = mesh_dist.receive_mesh_part()
-#
-#     vol_discr = Discretization(cl_ctx, local_mesh, order=order)
-#
-#     if 0:
-#         sym_x = sym.nodes(local_mesh.dim)
-#         myfunc_symb = sym.sin(np.dot(sym_x, [2, 3]))
-#         myfunc = bind(vol_discr, myfunc_symb)(queue)
-#
-#         sym_all_faces_func = sym.cse(
-#             sym.interp("vol", "all_faces")(sym.var("myfunc")))
-#         sym_int_faces_func = sym.cse(
-#             sym.interp("vol", "int_faces")(sym.var("myfunc")))
-#         sym_bdry_faces_func = sym.cse(
-#             sym.interp(sym.BTAG_ALL, "all_faces")(
-#                 sym.interp("vol", sym.BTAG_ALL)(sym.var("myfunc"))))
-#
-#         bound_face_swap = bind(vol_discr,
-#             sym.interp("int_faces", "all_faces")(
-#                 sym.OppositeInteriorFaceSwap("int_faces")(
-#                     sym_int_faces_func)
-#                 ) - (sym_all_faces_func - sym_bdry_faces_func)
-#                 )
-#
-#         hopefully_zero = bound_face_swap(queue, myfunc=myfunc)
-#         np.set_printoptions(threshold=100000000, suppress=True)
-#         print(hopefully_zero)
-#
-#         import numpy.linalg as la
-#         print(la.norm(hopefully_zero.get()))
-#     else:
-#         sym_x = sym.nodes(local_mesh.dim)
-#         myfunc_symb = sym.sin(np.dot(sym_x, [2, 3]))
-#         myfunc = bind(vol_discr, myfunc_symb)(queue)
-#
-#         sym_all_faces_func = sym.cse(
-#             sym.interp("vol", "all_faces")(sym.var("myfunc"))
-#             - sym.interp(sym.BTAG_ALL, "all_faces")(
-#                 sym.interp("vol", sym.BTAG_ALL)(sym.var("myfunc")))
-#             )
-#         sym_int_faces_func = sym.cse(
-#             sym.interp("vol", "int_faces")(sym.var("myfunc")))
-#
-#         swapped = bind(vol_discr,
-#             sym.interp("int_faces", "all_faces")(
-#                 sym.OppositeInteriorFaceSwap("int_faces")(
-#                     sym_int_faces_func)
-#                 ))(queue, myfunc=myfunc)
-#         unswapped = bind(vol_discr, sym_all_faces_func)(queue, myfunc=myfunc)
-#
-#         together = np.zeros((3,)+swapped.shape)
-#         print(together.shape)
-#         together[0] = swapped.get()
-#         together[1] = unswapped.get()
-#         together[2] = together[1]-together[0]
-#
-#         np.set_printoptions(threshold=100000000, suppress=True, linewidth=150)
-#         print(together.T)
-#
-#         import numpy.linalg as la
-#         print(la.norm(hopefully_zero.get()))
-#     1/0
-#
-#     w = sym.make_sym_array("w", vol_discr.dim+1)
-#     operator = sym.InverseMassOperator()(
-#                     sym.FaceMassOperator()(sym.int_tpair(w)))
-#
-#     # print(sym.pretty(operator)
-#     bound_op = bind(vol_discr, operator)
-#     # print(bound_op)
-#     # 1/0
-#
-#     def rhs(t, w):
-#         return bound_op(queue, t=t, w=w)
-#
-#     from pytools.obj_array import join_fields
-#     fields = join_fields(vol_discr.zeros(queue),
-#             [vol_discr.zeros(queue) for i in range(vol_discr.dim)])
-#
-#     dt_stepper = set_up_rk4("w", dt, fields, rhs)
-#
-#     final_t = 10
-#     nsteps = int(final_t/dt)
-#     print("rank=%d dt=%g nsteps=%d" % (rank, dt, nsteps))
-#
-#     from grudge.shortcuts import make_visualizer
-#     vis = make_visualizer(vol_discr, vis_order=order)
-#
-#     step = 0
-#
-#     norm = bind(vol_discr, sym.norm(2, sym.var("u")))
-#
-#     from time import time
-#     t_last_step = time()
-#
-#     for event in dt_stepper.run(t_end=final_t):
-#         if isinstance(event, dt_stepper.StateComputed):
-#             assert event.component_id == "w"
-#
-#             step += 1
-#
-#             print(step, event.t, norm(queue, u=event.state_component[0]),
-#                     time()-t_last_step)
-#             if step % 10 == 0:
-#                 vis.write_vtk_file("rank%d-fld-%04d.vtu" % (rank, step),
-#                         [
-#                             ("u", event.state_component[0]),
-#                             ("v", event.state_component[1:]),
-#                             ])
-#             t_last_step = time()
-#     logger.debug("Rank %d exiting", rank)
+def boundary_communication_entrypoint():
+    cl_ctx = cl.create_some_context()
+    queue = cl.CommandQueue(cl_ctx)
+    from meshmode.distributed import MPIMeshDistributor
+
+    from mpi4py import MPI
+    comm = MPI.COMM_WORLD
+    num_parts = comm.Get_size()
+
+    mesh_dist = MPIMeshDistributor(comm)
+
+    order = 2
+
+    if mesh_dist.is_mananger_rank():
+        from meshmode.mesh.generation import generate_regular_rect_mesh
+        mesh = generate_regular_rect_mesh(a=(-0.5,)*2,
+                                          b=(0.5,)*2,
+                                          n=(3,)*2)
+
+        from pymetis import part_graph
+        _, p = part_graph(num_parts,
+                          xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
+                          adjncy=mesh.nodal_adjacency.neighbors.tolist())
+        part_per_element = np.array(p)
+
+        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+    else:
+        local_mesh = mesh_dist.receive_mesh_part()
+
+    vol_discr = Discretization(cl_ctx, local_mesh, order=order)
+
+    sym_x = sym.nodes(local_mesh.dim)
+    myfunc_symb = sym.sin(np.dot(sym_x, [2, 3]))
+    myfunc = bind(vol_discr, myfunc_symb)(queue)
+
+    sym_all_faces_func = sym.cse(
+        sym.interp("vol", "all_faces")(sym.var("myfunc")))
+    sym_int_faces_func = sym.cse(
+        sym.interp("vol", "int_faces")(sym.var("myfunc")))
+    sym_bdry_faces_func = sym.cse(
+        sym.interp(sym.BTAG_ALL, "all_faces")(
+            sym.interp("vol", sym.BTAG_ALL)(sym.var("myfunc"))))
+
+    bound_face_swap = bind(vol_discr,
+        sym.interp("int_faces", "all_faces")(
+            sym.OppositeInteriorFaceSwap("int_faces")(
+                sym_int_faces_func)
+            ) - (sym_all_faces_func - sym_bdry_faces_func)
+            )
+
+    hopefully_zero = bound_face_swap(queue, myfunc=myfunc)
+    import numpy.linalg as la
+    error = la.norm(hopefully_zero.get())
+
+    np.set_printoptions(threshold=100000000, suppress=True)
+    print(hopefully_zero)
+    print(error)
+
+    assert error < 1e-14
+
 
 def mpi_communication_entrypoint():
     cl_ctx = cl.create_some_context()
@@ -295,15 +208,19 @@ def mpi_communication_entrypoint():
 # {{{ MPI test pytest entrypoint
 
 @pytest.mark.mpi
-@pytest.mark.parametrize("num_partitions", [2])
-def test_mpi_communication(num_partitions):
+@pytest.mark.parametrize("testcase", [
+    # "MPI_COMMUNICATION",
+    "BOUNDARY_COMMUNICATION"
+    ])
+@pytest.mark.parametrize("num_ranks", [2])
+def test_mpi(testcase, num_ranks):
     pytest.importorskip("mpi4py")
 
-    num_ranks = num_partitions
     from subprocess import check_call
     import sys
     newenv = os.environ.copy()
     newenv["RUN_WITHIN_MPI"] = "1"
+    newenv[testcase] = "1"
     check_call([
         "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI",
         sys.executable, __file__],
@@ -313,8 +230,10 @@ def test_mpi_communication(num_partitions):
 
 
 if __name__ == "__main__":
-    if "RUN_WITHIN_MPI" in os.environ:
+    if "MPI_COMMUNICATION" in os.environ:
         mpi_communication_entrypoint()
+    elif "BOUNDARY_COMMUNICATION" in os.environ:
+        boundary_communication_entrypoint()
     else:
         import sys
         if len(sys.argv) > 1:
-- 
GitLab