From 89a5a86cd56a5f1643a040a201467c9a25cac9f9 Mon Sep 17 00:00:00 2001
From: Ellis <>
Date: Thu, 15 Feb 2018 16:10:40 -0600
Subject: [PATCH] grudge mpi communication

 grudge/                  | 53 ++++++++++++++---------
 grudge/symbolic/          | 65 ++++++++++++++++++++++++++++
 grudge/symbolic/ |  3 ++
 grudge/symbolic/mappers/  |  2 +-
 grudge/symbolic/         |  2 +
 test/       | 26 +++++------
 6 files changed, 115 insertions(+), 36 deletions(-)

diff --git a/grudge/ b/grudge/
index 7b27390..cdc7579 100644
--- a/grudge/
+++ b/grudge/
@@ -36,7 +36,8 @@ import logging
 logger = logging.getLogger(__name__)
-MPI_TAG_GRUDGE_DATA = 0x3700d3e
+# TODO: Maybe we should move this somewhere else.
+# MPI_TAG_GRUDGE_DATA = 0x3700d3e
 # {{{ exec mapper
@@ -251,27 +252,9 @@ class ExecutionMapper(mappers.Evaluator,
     def map_opposite_partition_face_swap(self, op, field_expr):
         assert op.dd_in == op.dd_out
         bdry_conn = self.discrwb.get_distributed_boundary_swap_connection(op.dd_in)
-        loc_bdry_vec = self.rec(field_expr).get(self.queue)
-        comm = self.discrwb.mpi_communicator
-        remote_rank = op.dd_in.domain_tag.part_nr
-        send_req = comm.Isend(loc_bdry_vec, remote_rank,
-                tag=MPI_TAG_GRUDGE_DATA)
-        recv_vec_host = np.empty_like(loc_bdry_vec)
-        comm.Recv(recv_vec_host, source=remote_rank, tag=MPI_TAG_GRUDGE_DATA)
-        send_req.wait()
-        recv_vec_dev = cl.array.to_device(self.queue, recv_vec_host)
-        shuffled_recv_vec = bdry_conn(self.queue, recv_vec_dev) \
-                .with_queue(self.queue)
-        return shuffled_recv_vec
+        remote_bdry_vec = self.rec(field_expr)  # swapped by RankDataSwapAssign
+        return bdry_conn(self.queue, remote_bdry_vec).with_queue(self.queue)
     def map_opposite_interior_face_swap(self, op, field_expr):
         return self.discrwb.opposite_face_connection()(
@@ -338,6 +321,34 @@ class ExecutionMapper(mappers.Evaluator,
     # {{{ instruction execution functions
+    def map_insn_rank_data_swap(self, insn):
+        """Swap the value of ``insn.field`` with rank ``insn.i_remote_rank``.
+
+        The evaluated field is copied to the host, posted with a non-blocking
+        ``Isend`` and the peer's data is then received with a blocking
+        ``Recv``; both use ``insn.tag``.
+
+        :returns: a pair of lists: one ``(insn.name, data)`` assignment holding
+            the received data (moved back to the device), and no futures.
+        """
+        local_data = self.rec(insn.field).get(self.queue)
+        comm = self.discrwb.mpi_communicator
+        send_req = comm.Isend(local_data, insn.i_remote_rank, tag=insn.tag)
+        # The receive buffer mirrors the local buffer (np.empty_like).
+        remote_data_host = np.empty_like(local_data)
+        comm.Recv(remote_data_host, source=insn.i_remote_rank, tag=insn.tag)
+        # Only wait on the send once the matching receive has completed.
+        send_req.wait()
+        remote_data = cl.array.to_device(self.queue, remote_data_host)
+        return [(insn.name, remote_data)], []
+        # Sketch of a deferred (future-based) variant, currently disabled:
+        # class Future:
+        #     def is_ready(self):
+        #         return comm.improbe(source=insn.i_remote_rank, tag=insn.tag)
+        #
+        #     def __call__(self):
+        #         remote_data_host = np.empty_like(local_data)
+        #         comm.Recv(remote_data_host, source=insn.i_remote_rank, tag=insn.tag)
+        #         send_req.wait()
+        #
+        #         remote_data = cl.array.to_device(queue, remote_data_host)
+        #         return [(insn.name, remote_data)], []
+        #
+        # return [], [Future()]
     def map_insn_loopy_kernel(self, insn):
         kwargs = {}
         kdescr = insn.kernel_descriptor
diff --git a/grudge/symbolic/ b/grudge/symbolic/
index c555cea..450b3cd 100644
--- a/grudge/symbolic/
+++ b/grudge/symbolic/
@@ -198,6 +198,50 @@ class Assign(AssignBase):
     mapper_method = intern("map_insn_assign")
+class RankDataSwapAssign(Instruction):
+    """Instruction that swaps the value of *field* with a remote rank.
+
+    .. attribute:: name
+        The name the swapped data is bound to (rendered as ``name <- field``).
+    .. attribute:: field
+    .. attribute:: i_remote_rank
+        The number of the remote rank that this instruction swaps data with.
+    .. attribute:: tag
+        The MPI tag used for the swap: :attr:`MPI_TAG_GRUDGE_DATA` plus
+        ``op.mpi_tag_offset``. The offset should be unique for each instance
+        within a particular rank.
+    .. attribute:: dd_out
+    .. attribute:: comment
+    """
+    # TODO: Is this number ok? We probably want it to be global.
+    MPI_TAG_GRUDGE_DATA = 0x3700d3e
+    def __init__(self, name, field, op):
+        self.name = name
+        self.field = field
+        self.i_remote_rank = op.i_remote_part
+        self.dd_out = op.dd_out
+        self.tag = self.MPI_TAG_GRUDGE_DATA + op.mpi_tag_offset
+        self.comment = "Swap data with rank %02d" % self.i_remote_rank
+    @memoize_method
+    def get_assignees(self):
+        return set()
+    @memoize_method
+    def get_dependencies(self):
+        return _make_dep_mapper(include_subscripts=False)(self.field)
+    def __str__(self):
+        # Three placeholders: the comment, the target name and the field.
+        return ("{\n"
+                "    /* %s */\n"
+                "    %s <- %s\n"
+                "}\n" % (self.comment, self.name, self.field))
+    mapper_method = intern("map_insn_rank_data_swap")
 class ToDiscretizationScopedAssign(Assign):
     scope_indicator = "(to discr)-"
@@ -933,6 +977,9 @@ class ToLoopyInstructionMapper(object):
+    def map_insn_rank_data_swap(self, insn):
+        # Rank data swap instructions are returned unchanged by this mapper.
+        return insn
     def map_insn_assign_to_discr_scoped(self, insn):
         return insn
@@ -1122,6 +1169,8 @@ class OperatorCompiler(mappers.IdentityMapper):
     def map_operator_binding(self, expr, codegen_state, name_hint=None):
         if isinstance(expr.op, sym.RefDiffOperatorBase):
             return self.map_ref_diff_op_binding(expr, codegen_state)
+        elif isinstance(expr.op, sym.OppositePartitionFaceSwap):
+            return self.map_rank_data_swap_binding(expr, codegen_state)
             # make sure operator assignments stand alone and don't get muddled
             # up in vector math
@@ -1180,6 +1229,22 @@ class OperatorCompiler(mappers.IdentityMapper):
             return self.expr_to_var[expr]
+    def map_rank_data_swap_binding(self, expr, codegen_state):
+        """Compile an operator binding whose operator is a partition face swap.
+
+        On first sight of *expr*, a :class:`RankDataSwapAssign` instruction
+        for the compiled field is appended to the code list, and the operator
+        is then applied to the variable naming that instruction's result.
+        The resulting variable is cached in ``self.expr_to_var`` and returned.
+        """
+        try:
+            return self.expr_to_var[expr]
+        except KeyError:
+            field = self.rec(expr.field, codegen_state)
+            name = self.name_gen("raw_rank%02d_bdry_data" % expr.op.i_remote_part)
+            field_insn = RankDataSwapAssign(name=name, field=field, op=expr.op)
+            codegen_state.get_code_list(self).append(field_insn)
+            # Refer to the raw swapped data by the instruction's name.
+            field_var = Variable(field_insn.name)
+            # TODO: Do I need this?
+            # self.expr_to_var[field] = field_var
+            self.expr_to_var[expr] = self.assign_to_new_var(codegen_state,
+                                                            expr.op(field_var),
+                                                            prefix="other")
+            return self.expr_to_var[expr]
     # }}}
 # }}}
diff --git a/grudge/symbolic/ b/grudge/symbolic/
index 7e1de60..92be126 100644
--- a/grudge/symbolic/
+++ b/grudge/symbolic/
@@ -201,6 +201,9 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin):
                 for name, expr in zip(insn.names, insn.exprs)
+    def map_insn_rank_data_swap(self, insn):
+        """Return a one-element list pairing ``insn.name`` with ``insn.dd_out``."""
+        return [(insn.name, insn.dd_out)]
     map_insn_assign_to_discr_scoped = map_insn_assign
     def map_insn_diff_batch_assign(self, insn):
diff --git a/grudge/symbolic/mappers/ b/grudge/symbolic/mappers/
index 2ddd6f5..9db1ab3 100644
--- a/grudge/symbolic/mappers/
+++ b/grudge/symbolic/mappers/
@@ -661,7 +661,7 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper):
         elif dd.domain_tag is FACE_RESTR_INTERIOR:
             result = "int_faces"
         elif isinstance(dd.domain_tag, BTAG_PARTITION):
-            result = "rank%d_faces" % dd.domain_tag.part_nr
+            result = "part%d_faces" % dd.domain_tag.part_nr
             result = fmt(dd.domain_tag)
diff --git a/grudge/symbolic/ b/grudge/symbolic/
index 294c437..7cdb3d2 100644
--- a/grudge/symbolic/
+++ b/grudge/symbolic/
@@ -427,6 +427,8 @@ class OppositePartitionFaceSwap(Operator):
             raise ValueError("dd_out and dd_in must be identical")
         self.i_remote_part = self.dd_in.domain_tag.part_nr
+        # FIXME: We should have a unique offset for each instance on a particular rank
+        self.mpi_tag_offset = 0
     mapper_method = intern("map_opposite_partition_face_swap")
diff --git a/test/ b/test/
index 208de1a..3bf012f 100644
--- a/test/
+++ b/test/
@@ -36,7 +36,7 @@ from grudge import sym, bind, DGDiscretizationWithBoundaries
 from grudge.shortcuts import set_up_rk4
-def simple_mpi_communication_entrypoint(order):
+def simple_mpi_communication_entrypoint():
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
     from meshmode.distributed import MPIMeshDistributor
@@ -53,19 +53,17 @@ def simple_mpi_communication_entrypoint(order):
-        # This gives [0, 0, 0, 1, 0, 1, 1, 1]
-        # from pymetis import part_graph
-        # _, p = part_graph(num_parts,
-        #                   xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
-        #                   adjncy=mesh.nodal_adjacency.neighbors.tolist())
-        # part_per_element = np.array(p)
-        part_per_element = np.array([0, 0, 0, 1, 0, 1, 1, 1])
+        from pymetis import part_graph
+        _, p = part_graph(num_parts,
+                          xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
+                          adjncy=mesh.nodal_adjacency.neighbors.tolist())
+        part_per_element = np.array(p)
         local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
         local_mesh = mesh_dist.receive_mesh_part()
-    vol_discr = DGDiscretizationWithBoundaries(cl_ctx, local_mesh, order=order,
+    vol_discr = DGDiscretizationWithBoundaries(cl_ctx, local_mesh, order=5,
     sym_x = sym.nodes(local_mesh.dim)
@@ -87,6 +85,9 @@ def simple_mpi_communication_entrypoint(order):
             ) - (sym_all_faces_func - sym_bdry_faces_func)
+    print(bound_face_swap)
+    # 1/0
     hopefully_zero = bound_face_swap(queue, myfunc=myfunc)
     import numpy.linalg as la
     error = la.norm(hopefully_zero.get())
@@ -227,8 +228,7 @@ def test_mpi(num_ranks):
 @pytest.mark.parametrize("num_ranks", [2])
-@pytest.mark.parametrize("order", [2])
-def test_simple_mpi(num_ranks, order):
+def test_simple_mpi(num_ranks):
     from subprocess import check_call
@@ -236,7 +236,6 @@ def test_simple_mpi(num_ranks, order):
     newenv = os.environ.copy()
     newenv["RUN_WITHIN_MPI"] = "1"
-    newenv["order"] = str(order)
         "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI",
         sys.executable, __file__],
@@ -250,8 +249,7 @@ if __name__ == "__main__":
         if "TEST_MPI_COMMUNICATION" in os.environ:
         elif "TEST_SIMPLE_MPI_COMMUNICATION" in os.environ:
-            order = int(os.environ["order"])
-            simple_mpi_communication_entrypoint(order)
+            simple_mpi_communication_entrypoint()
         import sys
         if len(sys.argv) > 1: