From f54331008713920bdd7a020f528f4a8894b748de Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 18 Jun 2020 16:21:10 -0500
Subject: [PATCH] Eager interface: implement MPI communication, add example

---
 examples/wave/wave-eager-mpi.py | 206 ++++++++++++++++++++++++++++++++
 grudge/eager.py                 |  97 +++++++++++++--
 2 files changed, 295 insertions(+), 8 deletions(-)
 create mode 100644 examples/wave/wave-eager-mpi.py

diff --git a/examples/wave/wave-eager-mpi.py b/examples/wave/wave-eager-mpi.py
new file mode 100644
index 00000000..1d013efc
--- /dev/null
+++ b/examples/wave/wave-eager-mpi.py
@@ -0,0 +1,206 @@
+from __future__ import division, print_function
+
+__copyright__ = "Copyright (C) 2020 Andreas Kloeckner"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+import numpy as np
+import numpy.linalg as la  # noqa
+import pyopencl as cl
+
+from pytools.obj_array import flat_obj_array, make_obj_array
+
+from meshmode.array_context import PyOpenCLArrayContext
+from meshmode.dof_array import thaw
+
+from meshmode.mesh import BTAG_ALL, BTAG_NONE  # noqa
+
+from grudge.eager import (
+        EagerDGDiscretization, interior_trace_pair, cross_rank_trace_pairs)
+from grudge.shortcuts import make_visualizer
+from grudge.symbolic.primitives import TracePair
+from mpi4py import MPI
+
+
+# {{{ wave equation bits
+
+def wave_flux(discr, c, w_tpair):
+    """Return the upwind flux for *w_tpair*, interpolated to "all_faces"."""
+    u_tpair = w_tpair[0]
+    v_tpair = w_tpair[1:]
+
+    nhat = thaw(u_tpair.int.array_context, discr.normal(w_tpair.dd))
+
+    def times_normal(scalar):
+        # workaround for object array behavior
+        return make_obj_array([nhat_i*scalar for nhat_i in nhat])
+
+    # central part, built from the .avg of the trace pairs
+    flux_weak = flat_obj_array(
+            np.dot(v_tpair.avg, nhat), times_normal(u_tpair.avg))
+
+    # upwind
+    v_jump = np.dot(nhat, v_tpair.int-v_tpair.ext)
+    flux_weak -= flat_obj_array(
+            0.5*(u_tpair.int-u_tpair.ext),
+            0.5*times_normal(v_jump),
+            )
+
+    return discr.interp(w_tpair.dd, "all_faces", c*flux_weak)
+
+
+def wave_operator(discr, c, w):
+    """Evaluate the right-hand side of the first-order wave system for *w*.
+
+    Face fluxes come from interior faces, the ``BTAG_ALL`` boundary, and
+    the trace pairs exchanged with other ranks.
+    """
+    u = w[0]
+    v = w[1:]
+
+    # boundary trace pair: the exterior state has the sign of u flipped
+    bdry_u = discr.interp("vol", BTAG_ALL, u)
+    bdry_v = discr.interp("vol", BTAG_ALL, v)
+    bdry_tpair = TracePair(
+            BTAG_ALL,
+            flat_obj_array(bdry_u, bdry_v),
+            flat_obj_array(-bdry_u, bdry_v))
+
+    vol_term = flat_obj_array(c*discr.weak_div(v), c*discr.weak_grad(u))
+    face_term = (
+            wave_flux(discr, c=c, w_tpair=interior_trace_pair(discr, w))
+            + wave_flux(discr, c=c, w_tpair=bdry_tpair)
+            + sum(
+                wave_flux(discr, c=c, w_tpair=tpair)
+                for tpair in cross_rank_trace_pairs(discr, w)))
+
+    return discr.inverse_mass(vol_term - discr.face_mass(face_term))
+
+# }}}
+
+
+def rk4_step(y, t, h, f):  # one classical RK4 step of size h for y' = f(t, y)
+    stage_a = f(t, y)
+    stage_b = f(t+h/2, y + h/2*stage_a)
+    stage_c = f(t+h/2, y + h/2*stage_b)
+    stage_d = f(t+h, y + h*stage_c)
+    return y + h/6*(stage_a + 2*stage_b + 2*stage_c + stage_d)
+
+
+def bump(discr, actx, t=0):
+    """Return a Gaussian bump on the nodes of *discr*, scaled by cos(3*t)."""
+    center = np.array([0.2, 0.35, 0.1])[:discr.dim]
+    width = 0.05
+    omega = 3
+
+    nodes = thaw(actx, discr.nodes())
+    offsets = flat_obj_array([
+        nodes[axis] - center[axis] for axis in range(discr.dim)])
+    dist_squared = np.dot(offsets, offsets)
+
+    # the time factor is a scalar; the spatial factor is evaluated on the nodes
+    time_factor = np.cos(omega*t)
+    space_factor = actx.np.exp(-dist_squared / width**2)
+
+    return time_factor * space_factor
+
+
+def main():
+    """Solve the wave equation with the eager DG interface across MPI ranks.
+
+    On every tenth step, each rank prints a norm and writes its own VTK file.
+    """
+    cl_ctx = cl.create_some_context()
+    actx = PyOpenCLArrayContext(cl.CommandQueue(cl_ctx))
+
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    num_parts = comm.Get_size()
+
+    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
+    mesh_dist = MPIMeshDistributor(comm)
+
+    dim = 2
+    nel_1d = 16
+    order = 3
+
+    # the manager rank builds and partitions the mesh, the others receive a part
+    if mesh_dist.is_mananger_rank():
+        from meshmode.mesh.generation import generate_regular_rect_mesh
+        mesh = generate_regular_rect_mesh(
+                a=(-0.5,)*dim,
+                b=(0.5,)*dim,
+                n=(nel_1d,)*dim)
+
+        print("%d elements" % mesh.nelements)
+
+        part_per_element = get_partition_by_pymetis(mesh, num_parts)
+        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+
+        del mesh
+    else:
+        local_mesh = mesh_dist.receive_mesh_part()
+
+    discr = EagerDGDiscretization(
+            actx, local_mesh, order=order, mpi_communicator=comm)
+
+    # no deep meaning here, just fudge factors
+    dt_fudge_factors = {2: 0.75, 3: 0.45}
+    if dim not in dt_fudge_factors:
+        raise ValueError("don't have a stable time step guesstimate")
+    dt = dt_fudge_factors[dim]/(nel_1d*order**2)
+
+    # state vector: u starts as the bump, each component of v as zero
+    fields = flat_obj_array(
+            bump(discr, actx),
+            [discr.zeros(actx) for _ in range(discr.dim)]
+            )
+
+    vis_order = discr.order+3 if dim == 2 else discr.order
+    vis = make_visualizer(discr, vis_order)
+
+    def rhs(t, w):
+        return wave_operator(discr, c=1, w=w)
+
+    t = 0
+    t_final = 3
+    istep = 0
+    while t < t_final:
+        fields = rk4_step(fields, t, dt, rhs)
+
+        # report and visualize every tenth step
+        if istep % 10 == 0:
+            print(istep, t, discr.norm(fields[0]))
+            vtk_name = "fld-wave-eager-mpi-%03d-%04d.vtu" % (rank, istep)
+            vis.write_vtk_file(vtk_name, [
+                ("u", fields[0]),
+                ("v", fields[1:]),
+                ])
+
+        t += dt
+        istep += 1
+
+
+if __name__ == "__main__":  # only run the example when executed as a script
+    main()
+
+# vim: foldmethod=marker
diff --git a/grudge/eager.py b/grudge/eager.py
index 33b2c587..05f1bb37 100644
--- a/grudge/eager.py
+++ b/grudge/eager.py
@@ -24,14 +24,16 @@ THE SOFTWARE.
 
 
 import numpy as np  # noqa
-from grudge.discretization import DGDiscretizationWithBoundaries
 from pytools import memoize_method
-from pytools.obj_array import obj_array_vectorize
+from pytools.obj_array import obj_array_vectorize, make_obj_array
 import pyopencl.array as cla  # noqa
 from grudge import sym, bind
 
-from meshmode.mesh import BTAG_ALL, BTAG_NONE  # noqa
-from meshmode.dof_array import freeze, DOFArray
+from meshmode.mesh import BTAG_ALL, BTAG_NONE, BTAG_PARTITION  # noqa
+from meshmode.dof_array import freeze, DOFArray, flatten, unflatten
+
+from grudge.discretization import DGDiscretizationWithBoundaries
+from grudge.symbolic.primitives import TracePair
 
 
 class EagerDGDiscretization(DGDiscretizationWithBoundaries):
@@ -115,19 +117,98 @@ class EagerDGDiscretization(DGDiscretizationWithBoundaries):
     def norm(self, vec, p=2):
         return self._norm(p)(arg=vec)
 
+    @memoize_method
+    def connected_ranks(self):  # memoized; callers iterate it as remote ranks
+        from meshmode.distributed import get_connected_partitions
+        return get_connected_partitions(self._volume_discr.mesh)
 
-def interior_trace_pair(discr, vec):
-    i = discr.interp("vol", "int_faces", vec)
+
+def interior_trace_pair(discrwb, vec):  # FIXME: e is unbound for a plain DOFArray
+    i = discrwb.interp("vol", "int_faces", vec)
 
     if (isinstance(vec, np.ndarray)
             and vec.dtype.char == "O"
             and not isinstance(vec, DOFArray)):
         e = obj_array_vectorize(
-                lambda el: discr.opposite_face_connection()(el),
+                lambda el: discrwb.opposite_face_connection()(el),
                 i)
 
-    from grudge.symbolic.primitives import TracePair
     return TracePair("int_faces", i, e)
 
 
+class RankBoundaryCommunication:
+    """Exchange of one field's boundary values with a single remote rank.
+
+    The constructor starts a nonblocking send and receive; :meth:`finish`
+    waits for both and returns the resulting :class:`TracePair`.
+    """
+    base_tag = 1273
+
+    def __init__(self, discrwb, remote_rank, vol_field, tag=None):
+        # an optional *tag* acts as an offset from base_tag
+        self.tag = self.base_tag if tag is None else self.base_tag + tag
+
+        self.discrwb = discrwb
+        self.array_context = vol_field.array_context
+        self.remote_btag = BTAG_PARTITION(remote_rank)
+        self.bdry_discr = discrwb.discr_from_dd(self.remote_btag)
+        self.local_dof_array = discrwb.interp("vol", self.remote_btag, vol_field)
+
+        send_buf = self.array_context.to_numpy(flatten(self.local_dof_array))
+        self.remote_data_host = np.empty_like(send_buf)
+
+        comm = discrwb.mpi_communicator
+        self.send_req = comm.Isend(send_buf, remote_rank, tag=self.tag)
+        self.recv_req = comm.Irecv(self.remote_data_host, remote_rank, self.tag)
+
+    def finish(self):
+        """Wait for the exchange and return the local/remote trace pair."""
+        self.recv_req.Wait()
+
+        actx = self.array_context
+        received = unflatten(
+                actx, self.bdry_discr, actx.from_numpy(self.remote_data_host))
+        # run the received data through the distributed boundary swap connection
+        swap_conn = self.discrwb.get_distributed_boundary_swap_connection(
+                sym.as_dofdesc(sym.DTAG_BOUNDARY(self.remote_btag)))
+        swapped = swap_conn(received)
+
+        # the send has to complete as well before the exchange is over
+        self.send_req.Wait()
+        return TracePair(self.remote_btag, self.local_dof_array, swapped)
+
+
+def _cross_rank_trace_pairs_scalar_field(discrwb, vec, tag=None):
+    ranks = discrwb.connected_ranks()
+    pending = [RankBoundaryCommunication(discrwb, r, vec, tag=tag) for r in ranks]
+    return [xchg.finish() for xchg in pending]  # all started before any wait
+
+
+def cross_rank_trace_pairs(discrwb, vec, tag=None):
+    """Return a list with one TracePair per connected rank for *vec*.
+
+    *tag*, if given, is passed on to every per-rank exchange.
+    """
+    if not (isinstance(vec, np.ndarray)
+            and vec.dtype.char == "O"
+            and not isinstance(vec, DOFArray)):
+        return _cross_rank_trace_pairs_scalar_field(discrwb, vec, tag=tag)
+
+    n, = vec.shape
+    result = {}
+    for ivec in range(n):
+        # forward *tag* here too, so object arrays honor it like scalars do
+        for rank_tpair in _cross_rank_trace_pairs_scalar_field(
+                discrwb, vec[ivec], tag=tag):
+            assert isinstance(rank_tpair.dd.domain_tag, sym.DTAG_BOUNDARY)
+            assert isinstance(rank_tpair.dd.domain_tag.tag, BTAG_PARTITION)
+            result[rank_tpair.dd.domain_tag.tag.part_nr, ivec] = rank_tpair
+    return [
+        TracePair(
+            dd=sym.as_dofdesc(sym.DTAG_BOUNDARY(BTAG_PARTITION(remote_rank))),
+            interior=make_obj_array([result[remote_rank, i].int for i in range(n)]),
+            exterior=make_obj_array([result[remote_rank, i].ext for i in range(n)]))
+        for remote_rank in discrwb.connected_ranks()]
+
+
 # vim: foldmethod=marker
-- 
GitLab