diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py
index a343beb0a7b81b35df4f7144fb815c4f9b60eb22..4d3026b35e3449d68c7d28afd8dd799672532749 100644
--- a/test/test_mpi_communication.py
+++ b/test/test_mpi_communication.py
@@ -27,15 +27,19 @@ THE SOFTWARE.
 
 import pytest
 import os
-
+import numpy as np
+import pyopencl as cl
 import logging
 logger = logging.getLogger(__name__)
 
-import numpy as np
+from grudge import sym, bind, Discretization
+from grudge.shortcuts import set_up_rk4
 
 
 def mpi_communication_entrypoint():
-    from meshmode.distributed import MPIMeshDistributor, MPIBoundaryCommunicator
+    cl_ctx = cl.create_some_context()
+    queue = cl.CommandQueue(cl_ctx)
+    from meshmode.distributed import MPIMeshDistributor
     from mpi4py import MPI
     comm = MPI.COMM_WORLD
     rank = comm.Get_rank()
@@ -44,30 +48,97 @@ def mpi_communication_entrypoint():
 
     mesh_dist = MPIMeshDistributor(comm)
 
-    if mesh_dist.is_mananger_rank():
-        np.random.seed(42)
-        from meshmode.mesh.generation import generate_warped_rect_mesh
-        meshes = [generate_warped_rect_mesh(3, order=4, n=4) for _ in range(2)]
+    dims = 2
+    dt = 0.04
+    order = 4
 
-        from meshmode.mesh.processing import merge_disjoint_meshes
-        mesh = merge_disjoint_meshes(meshes)
+    if mesh_dist.is_mananger_rank():
+        from meshmode.mesh.generation import generate_regular_rect_mesh
+        mesh = generate_regular_rect_mesh(a=(-0.5,)*dims,
+                                          b=(0.5,)*dims,
+                                          n=(16,)*dims)
 
-        part_per_element = np.random.randint(num_parts, size=mesh.nelements)
+        from pymetis import part_graph
+        _, p = part_graph(num_parts,
+                          xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
+                          adjncy=mesh.nodal_adjacency.neighbors.tolist())
+        part_per_element = np.array(p)
 
         local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
     else:
         local_mesh = mesh_dist.receive_mesh_part()
 
-    from meshmode.discretization.poly_element\
-            import PolynomialWarpAndBlendGroupFactory
-    group_factory = PolynomialWarpAndBlendGroupFactory(4)
-    import pyopencl as cl
-    cl_ctx = cl.create_some_context()
-    queue = cl.CommandQueue(cl_ctx)
+    vol_discr = Discretization(cl_ctx, local_mesh, order=order)
+
+    source_center = np.array([0.1, 0.22, 0.33])[:local_mesh.dim]
+    source_width = 0.05
+    source_omega = 3
+
+    sym_x = sym.nodes(local_mesh.dim)
+    sym_source_center_dist = sym_x - source_center
+    sym_t = sym.ScalarVariable("t")
+
+    from grudge.models.wave import StrongWaveOperator
+    from meshmode.mesh import BTAG_ALL, BTAG_NONE
+    op = StrongWaveOperator(-0.1, vol_discr.dim,
+            source_f=(
+                sym.sin(source_omega*sym_t)
+                * sym.exp(
+                    -np.dot(sym_source_center_dist, sym_source_center_dist)
+                    / source_width**2)),
+            dirichlet_tag=BTAG_NONE,
+            neumann_tag=BTAG_NONE,
+            radiation_tag=BTAG_ALL,
+            flux_type="upwind")
+
+    from pytools.obj_array import join_fields
+    fields = join_fields(vol_discr.zeros(queue),
+            [vol_discr.zeros(queue) for i in range(vol_discr.dim)])
+
+    # FIXME
+    #dt = op.estimate_rk4_timestep(vol_discr, fields=fields)
+
+    op.check_bc_coverage(local_mesh)
+
+    # print(sym.pretty(op.sym_operator()))
+    bound_op = bind(vol_discr, op.sym_operator())
+    # print(bound_op)
+    # 1/0
+
+    def rhs(t, w):
+        return bound_op(queue, t=t, w=w)
+
+    dt_stepper = set_up_rk4("w", dt, fields, rhs)
+
+    final_t = 10
+    nsteps = int(final_t/dt)
+    print("dt=%g nsteps=%d" % (dt, nsteps))
+
+    from grudge.shortcuts import make_visualizer
+    vis = make_visualizer(vol_discr, vis_order=order)
+
+    step = 0
+
+    norm = bind(vol_discr, sym.norm(2, sym.var("u")))
+
+    from time import time
+    t_last_step = time()
+
+    for event in dt_stepper.run(t_end=final_t):
+        if isinstance(event, dt_stepper.StateComputed):
+            assert event.component_id == "w"
 
-    from meshmode.discretization import Discretization
-    vol_discr = Discretization(cl_ctx, local_mesh, group_factory)
+            step += 1
+            print(step, event.t, norm(queue, u=event.state_component[0]),
+                    time()-t_last_step)
+            if step % 10 == 0:
+                vis.write_vtk_file("r%d-fld-%04d.vtu" % (rank, step),
+                        [
+                            ("u", event.state_component[0]),
+                            ("v", event.state_component[1:]),
+                            ])
+            t_last_step = time()
 
     logger.debug("Rank %d exiting", rank)