From 3791c78c52ef3c41d3a8a0daa59b5b4241899388 Mon Sep 17 00:00:00 2001 From: Ellis <eshoag@illinois.edu> Date: Thu, 18 Jan 2018 12:20:24 -0600 Subject: [PATCH] Add simple test case --- examples/wave/wave-min.py | 2 - test/test_mpi_communication.py | 221 +++++++++++---------------------- 2 files changed, 70 insertions(+), 153 deletions(-) diff --git a/examples/wave/wave-min.py b/examples/wave/wave-min.py index bd3424b..aa119aa 100644 --- a/examples/wave/wave-min.py +++ b/examples/wave/wave-min.py @@ -84,8 +84,6 @@ def main(write_output=True, order=4): # print(sym.pretty(op.sym_operator())) bound_op = bind(discr, op.sym_operator()) - print(bound_op) - 1/0 def rhs(t, w): return bound_op(queue, t=t, w=w) diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 30710d6..05def1d 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -36,153 +36,66 @@ from grudge import sym, bind, Discretization from grudge.shortcuts import set_up_rk4 -# TODO: Make new test -# Create a partitioned mesh and apply sin(2x + 3y) to its field -# If everything is working, the boundaries of the partitions should be continuous -# Look at int_tpair -# Interpolate volume to boundary, ask for the opposite partition at the boundary -# then compare -# def mpi_communication_entrypoint(): -# cl_ctx = cl.create_some_context() -# queue = cl.CommandQueue(cl_ctx) -# from meshmode.distributed import MPIMeshDistributor -# -# from mpi4py import MPI -# comm = MPI.COMM_WORLD -# rank = comm.Get_rank() -# num_parts = comm.Get_size() -# -# mesh_dist = MPIMeshDistributor(comm) -# -# dims = 2 -# dt = 0.04 -# order = 6 -# -# if mesh_dist.is_mananger_rank(): -# from meshmode.mesh.generation import generate_regular_rect_mesh -# mesh = generate_regular_rect_mesh(a=(-0.5,)*dims, -# b=(0.5,)*dims, -# n=(16,)*dims) -# -# from pymetis import part_graph -# _, p = part_graph(num_parts, -# xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), -# adjncy=mesh.nodal_adjacency.neighbors.tolist()) -# part_per_element = np.array(p) -# -# local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) -# else: -# local_mesh = mesh_dist.receive_mesh_part() -# -# vol_discr = Discretization(cl_ctx, local_mesh, order=order) -# -# if 0: -# sym_x = sym.nodes(local_mesh.dim) -# myfunc_symb = sym.sin(np.dot(sym_x, [2, 3])) -# myfunc = bind(vol_discr, myfunc_symb)(queue) -# -# sym_all_faces_func = sym.cse( -# sym.interp("vol", "all_faces")(sym.var("myfunc"))) -# sym_int_faces_func = sym.cse( -# sym.interp("vol", "int_faces")(sym.var("myfunc"))) -# sym_bdry_faces_func = sym.cse( -# sym.interp(sym.BTAG_ALL, "all_faces")( -# sym.interp("vol", sym.BTAG_ALL)(sym.var("myfunc")))) -# -# bound_face_swap = bind(vol_discr, -# sym.interp("int_faces", "all_faces")( -# sym.OppositeInteriorFaceSwap("int_faces")( -# sym_int_faces_func) -# ) - (sym_all_faces_func - sym_bdry_faces_func) -# ) -# -# hopefully_zero = bound_face_swap(queue, myfunc=myfunc) -# np.set_printoptions(threshold=100000000, suppress=True) -# print(hopefully_zero) -# -# import numpy.linalg as la -# print(la.norm(hopefully_zero.get())) -# else: -# sym_x = sym.nodes(local_mesh.dim) -# myfunc_symb = sym.sin(np.dot(sym_x, [2, 3])) -# myfunc = bind(vol_discr, myfunc_symb)(queue) -# -# sym_all_faces_func = sym.cse( -# sym.interp("vol", "all_faces")(sym.var("myfunc")) -# - sym.interp(sym.BTAG_ALL, "all_faces")( -# sym.interp("vol", sym.BTAG_ALL)(sym.var("myfunc"))) -# ) -# sym_int_faces_func = sym.cse( -# sym.interp("vol", "int_faces")(sym.var("myfunc"))) -# -# swapped = bind(vol_discr, -# sym.interp("int_faces", "all_faces")( -# sym.OppositeInteriorFaceSwap("int_faces")( -# sym_int_faces_func) -# ))(queue, myfunc=myfunc) -# unswapped = bind(vol_discr, sym_all_faces_func)(queue, myfunc=myfunc) -# -# together = np.zeros((3,)+swapped.shape) -# print(together.shape) -# together[0] = swapped.get() -# together[1] = unswapped.get() -# together[2] = together[1]-together[0] -# -# np.set_printoptions(threshold=100000000, suppress=True, linewidth=150) -# print(together.T) -# -# import numpy.linalg as la -# print(la.norm(hopefully_zero.get())) -# 1/0 -# -# w = sym.make_sym_array("w", vol_discr.dim+1) -# operator = sym.InverseMassOperator()( -# sym.FaceMassOperator()(sym.int_tpair(w))) -# -# # print(sym.pretty(operator) -# bound_op = bind(vol_discr, operator) -# # print(bound_op) -# # 1/0 -# -# def rhs(t, w): -# return bound_op(queue, t=t, w=w) -# -# from pytools.obj_array import join_fields -# fields = join_fields(vol_discr.zeros(queue), -# [vol_discr.zeros(queue) for i in range(vol_discr.dim)]) -# -# dt_stepper = set_up_rk4("w", dt, fields, rhs) -# -# final_t = 10 -# nsteps = int(final_t/dt) -# print("rank=%d dt=%g nsteps=%d" % (rank, dt, nsteps)) -# -# from grudge.shortcuts import make_visualizer -# vis = make_visualizer(vol_discr, vis_order=order) -# -# step = 0 -# -# norm = bind(vol_discr, sym.norm(2, sym.var("u"))) -# -# from time import time -# t_last_step = time() -# -# for event in dt_stepper.run(t_end=final_t): -# if isinstance(event, dt_stepper.StateComputed): -# assert event.component_id == "w" -# -# step += 1 -# -# print(step, event.t, norm(queue, u=event.state_component[0]), -# time()-t_last_step) -# if step % 10 == 0: -# vis.write_vtk_file("rank%d-fld-%04d.vtu" % (rank, step), -# [ -# ("u", event.state_component[0]), -# ("v", event.state_component[1:]), -# ]) -# t_last_step = time() -# logger.debug("Rank %d exiting", rank) +def boundary_communication_entrypoint(): + cl_ctx = cl.create_some_context() + queue = cl.CommandQueue(cl_ctx) + from meshmode.distributed import MPIMeshDistributor + + from mpi4py import MPI + comm = MPI.COMM_WORLD + num_parts = comm.Get_size() + + mesh_dist = MPIMeshDistributor(comm) + + order = 2 + + if mesh_dist.is_mananger_rank(): + from meshmode.mesh.generation import generate_regular_rect_mesh + mesh = generate_regular_rect_mesh(a=(-0.5,)*2, + b=(0.5,)*2, + n=(3,)*2) + + from pymetis import part_graph + _, p = part_graph(num_parts, + xadj=mesh.nodal_adjacency.neighbors_starts.tolist(), + adjncy=mesh.nodal_adjacency.neighbors.tolist()) + part_per_element = np.array(p) + + local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts) + else: + local_mesh = mesh_dist.receive_mesh_part() + + vol_discr = Discretization(cl_ctx, local_mesh, order=order) + + sym_x = sym.nodes(local_mesh.dim) + myfunc_symb = sym.sin(np.dot(sym_x, [2, 3])) + myfunc = bind(vol_discr, myfunc_symb)(queue) + + sym_all_faces_func = sym.cse( + sym.interp("vol", "all_faces")(sym.var("myfunc"))) + sym_int_faces_func = sym.cse( + sym.interp("vol", "int_faces")(sym.var("myfunc"))) + sym_bdry_faces_func = sym.cse( + sym.interp(sym.BTAG_ALL, "all_faces")( + sym.interp("vol", sym.BTAG_ALL)(sym.var("myfunc")))) + + bound_face_swap = bind(vol_discr, + sym.interp("int_faces", "all_faces")( + sym.OppositeInteriorFaceSwap("int_faces")( + sym_int_faces_func) + ) - (sym_all_faces_func - sym_bdry_faces_func) + ) + + hopefully_zero = bound_face_swap(queue, myfunc=myfunc) + import numpy.linalg as la + error = la.norm(hopefully_zero.get()) + + np.set_printoptions(threshold=100000000, suppress=True) + print(hopefully_zero) + print(error) + + assert error < 1e-14 + def mpi_communication_entrypoint(): cl_ctx = cl.create_some_context() @@ -295,15 +208,19 @@ def mpi_communication_entrypoint(): # {{{ MPI test pytest entrypoint @pytest.mark.mpi -@pytest.mark.parametrize("num_partitions", [2]) -def test_mpi_communication(num_partitions): +@pytest.mark.parametrize("testcase", [ + # "MPI_COMMUNICATION", + "BOUNDARY_COMMUNICATION" + ]) +@pytest.mark.parametrize("num_ranks", [2]) +def test_mpi(testcase, num_ranks): pytest.importorskip("mpi4py") - num_ranks = num_partitions from subprocess import check_call import sys newenv = os.environ.copy() newenv["RUN_WITHIN_MPI"] = "1" + newenv[testcase] = "1" check_call([ "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI", sys.executable, __file__], @@ -313,8 +230,10 @@ def test_mpi_communication(num_partitions): if __name__ == "__main__": - if "RUN_WITHIN_MPI" in os.environ: + if "MPI_COMMUNICATION" in os.environ: mpi_communication_entrypoint() + elif "BOUNDARY_COMMUNICATION" in os.environ: + boundary_communication_entrypoint() else: import sys if len(sys.argv) > 1: -- GitLab