From 89a5a86cd56a5f1643a040a201467c9a25cac9f9 Mon Sep 17 00:00:00 2001
From: Ellis <eshoag@illinois.edu>
Date: Thu, 15 Feb 2018 16:10:40 -0600
Subject: [PATCH] grudge mpi communication

---
 grudge/execution.py                  | 53 ++++++++++++++---------
 grudge/symbolic/compiler.py          | 65 ++++++++++++++++++++++++++++
 grudge/symbolic/dofdesc_inference.py |  3 ++
 grudge/symbolic/mappers/__init__.py  |  2 +-
 grudge/symbolic/operators.py         |  2 +
 test/test_mpi_communication.py       | 26 +++++------
 6 files changed, 115 insertions(+), 36 deletions(-)

diff --git a/grudge/execution.py b/grudge/execution.py
index 7b27390..cdc7579 100644
--- a/grudge/execution.py
+++ b/grudge/execution.py
@@ -36,7 +36,8 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-MPI_TAG_GRUDGE_DATA = 0x3700d3e
+# TODO: Maybe we should move this somewhere else.
+# MPI_TAG_GRUDGE_DATA = 0x3700d3e
 
 
 # {{{ exec mapper
@@ -251,27 +252,9 @@ class ExecutionMapper(mappers.Evaluator,
 
     def map_opposite_partition_face_swap(self, op, field_expr):
         assert op.dd_in == op.dd_out
-
         bdry_conn = self.discrwb.get_distributed_boundary_swap_connection(op.dd_in)
-        loc_bdry_vec = self.rec(field_expr).get(self.queue)
-
-        comm = self.discrwb.mpi_communicator
-
-        remote_rank = op.dd_in.domain_tag.part_nr
-
-        send_req = comm.Isend(loc_bdry_vec, remote_rank,
-                tag=MPI_TAG_GRUDGE_DATA)
-
-        recv_vec_host = np.empty_like(loc_bdry_vec)
-        comm.Recv(recv_vec_host, source=remote_rank, tag=MPI_TAG_GRUDGE_DATA)
-        send_req.wait()
-
-        recv_vec_dev = cl.array.to_device(self.queue, recv_vec_host)
-
-        shuffled_recv_vec = bdry_conn(self.queue, recv_vec_dev) \
-                .with_queue(self.queue)
-
-        return shuffled_recv_vec
+        remote_bdry_vec = self.rec(field_expr)  # swapped by RankDataSwapAssign
+        return bdry_conn(self.queue, remote_bdry_vec).with_queue(self.queue)
 
     def map_opposite_interior_face_swap(self, op, field_expr):
         return self.discrwb.opposite_face_connection()(
@@ -338,6 +321,34 @@ class ExecutionMapper(mappers.Evaluator,
 
     # {{{ instruction execution functions
 
+    def map_insn_rank_data_swap(self, insn):
+        local_data = self.rec(insn.field).get(self.queue)
+        comm = self.discrwb.mpi_communicator
+
+        send_req = comm.Isend(local_data, insn.i_remote_rank, tag=insn.tag)
+
+        remote_data_host = np.empty_like(local_data)
+        comm.Recv(remote_data_host, source=insn.i_remote_rank, tag=insn.tag)
+        send_req.wait()
+        remote_data = cl.array.to_device(self.queue, remote_data_host)
+
+        return [(insn.name, remote_data)], []
+
+        # class Future:
+        #     def is_ready(self):
+        #         return comm.improbe(source=insn.i_remote_rank, tag=insn.tag)
+        #
+        #     def __call__(self):
+        #         remote_data_host = np.empty_like(local_data)
+        #         comm.Recv(remote_data_host, source=insn.i_remote_rank, tag=insn.tag)
+        #         send_req.wait()
+        #
+        #         remote_data = cl.array.to_device(queue, remote_data_host)
+        #         return [(insn.name, remote_data)], []
+        #
+        # return [], [Future()]
+
+
     def map_insn_loopy_kernel(self, insn):
         kwargs = {}
         kdescr = insn.kernel_descriptor
diff --git a/grudge/symbolic/compiler.py b/grudge/symbolic/compiler.py
index c555cea..450b3cd 100644
--- a/grudge/symbolic/compiler.py
+++ b/grudge/symbolic/compiler.py
@@ -198,6 +198,50 @@ class Assign(AssignBase):
     mapper_method = intern("map_insn_assign")
 
 
+class RankDataSwapAssign(Instruction):
+    """
+    .. attribute:: name
+    .. attribute:: field
+    .. attribute:: i_remote_rank
+
+        The number of the remote rank that this instruction swaps data with.
+
+    .. attribute:: mpi_tag_offset
+
+        A tag offset for mpi that should be unique for each instance within
+        a particular rank.
+
+    .. attribute:: dd_out
+    .. attribute:: comment
+    """
+    # TODO: Is this number ok? We probably want it to be global.
+    MPI_TAG_GRUDGE_DATA = 0x3700d3e
+
+    def __init__(self, name, field, op):
+        self.name = name
+        self.field = field
+        self.i_remote_rank = op.i_remote_part
+        self.dd_out = op.dd_out
+        self.tag = self.MPI_TAG_GRUDGE_DATA + op.mpi_tag_offset
+        self.comment = "Swap data with rank %02d" % self.i_remote_rank
+
+    @memoize_method
+    def get_assignees(self):
+        return set([self.name])
+
+    @memoize_method
+    def get_dependencies(self):
+        return _make_dep_mapper(include_subscripts=False)(self.field)
+
+    def __str__(self):
+        return ("{\n"
+                "    /* %s */\n"
+                "    %s <- %s\n"
+                "}\n" % (self.comment, self.name, self.field))
+
+    mapper_method = intern("map_insn_rank_data_swap")
+
+
 class ToDiscretizationScopedAssign(Assign):
     scope_indicator = "(to discr)-"
 
@@ -933,6 +977,9 @@ class ToLoopyInstructionMapper(object):
                 governing_dd=governing_dd)
             )
 
+    def map_insn_rank_data_swap(self, insn):
+        return insn
+
     def map_insn_assign_to_discr_scoped(self, insn):
         return insn
 
@@ -1122,6 +1169,8 @@ class OperatorCompiler(mappers.IdentityMapper):
     def map_operator_binding(self, expr, codegen_state, name_hint=None):
         if isinstance(expr.op, sym.RefDiffOperatorBase):
             return self.map_ref_diff_op_binding(expr, codegen_state)
+        elif isinstance(expr.op, sym.OppositePartitionFaceSwap):
+            return self.map_rank_data_swap_binding(expr, codegen_state)
         else:
             # make sure operator assignments stand alone and don't get muddled
             # up in vector math
@@ -1180,6 +1229,22 @@ class OperatorCompiler(mappers.IdentityMapper):
 
             return self.expr_to_var[expr]
 
+    def map_rank_data_swap_binding(self, expr, codegen_state):
+        try:
+            return self.expr_to_var[expr]
+        except KeyError:
+            field = self.rec(expr.field, codegen_state)
+            name = self.name_gen("raw_rank%02d_bdry_data" % expr.op.i_remote_part)
+            field_insn = RankDataSwapAssign(name=name, field=field, op=expr.op)
+            codegen_state.get_code_list(self).append(field_insn)
+            field_var = Variable(field_insn.name)
+            # TODO: Do I need this?
+            # self.expr_to_var[field] = field_var
+            self.expr_to_var[expr] = self.assign_to_new_var(codegen_state,
+                                                            expr.op(field_var),
+                                                            prefix="other")
+            return self.expr_to_var[expr]
+
     # }}}
 
 # }}}
diff --git a/grudge/symbolic/dofdesc_inference.py b/grudge/symbolic/dofdesc_inference.py
index 7e1de60..92be126 100644
--- a/grudge/symbolic/dofdesc_inference.py
+++ b/grudge/symbolic/dofdesc_inference.py
@@ -201,6 +201,9 @@ class DOFDescInferenceMapper(RecursiveMapper, CSECachingMapperMixin):
                 for name, expr in zip(insn.names, insn.exprs)
                 ]
 
+    def map_insn_rank_data_swap(self, insn):
+        return [(insn.name, insn.dd_out)]
+
     map_insn_assign_to_discr_scoped = map_insn_assign
 
     def map_insn_diff_batch_assign(self, insn):
diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py
index 2ddd6f5..9db1ab3 100644
--- a/grudge/symbolic/mappers/__init__.py
+++ b/grudge/symbolic/mappers/__init__.py
@@ -661,7 +661,7 @@ class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper):
         elif dd.domain_tag is FACE_RESTR_INTERIOR:
             result = "int_faces"
         elif isinstance(dd.domain_tag, BTAG_PARTITION):
-            result = "rank%d_faces" % dd.domain_tag.part_nr
+            result = "part%d_faces" % dd.domain_tag.part_nr
         else:
             result = fmt(dd.domain_tag)
 
diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py
index 294c437..7cdb3d2 100644
--- a/grudge/symbolic/operators.py
+++ b/grudge/symbolic/operators.py
@@ -427,6 +427,8 @@ class OppositePartitionFaceSwap(Operator):
             raise ValueError("dd_out and dd_in must be identical")
 
         self.i_remote_part = self.dd_in.domain_tag.part_nr
+        # FIXME: We should have a unique offset for each instance on a particular rank
+        self.mpi_tag_offset = 0
 
     mapper_method = intern("map_opposite_partition_face_swap")
 
diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py
index 208de1a..3bf012f 100644
--- a/test/test_mpi_communication.py
+++ b/test/test_mpi_communication.py
@@ -36,7 +36,7 @@ from grudge import sym, bind, DGDiscretizationWithBoundaries
 from grudge.shortcuts import set_up_rk4
 
 
-def simple_mpi_communication_entrypoint(order):
+def simple_mpi_communication_entrypoint():
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
     from meshmode.distributed import MPIMeshDistributor
@@ -53,19 +53,17 @@ def simple_mpi_communication_entrypoint(order):
                                           b=(1,)*2,
                                           n=(3,)*2)
 
-        # This gives [0, 0, 0, 1, 0, 1, 1, 1]
-        # from pymetis import part_graph
-        # _, p = part_graph(num_parts,
-        #                   xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
-        #                   adjncy=mesh.nodal_adjacency.neighbors.tolist())
-        # part_per_element = np.array(p)
-        part_per_element = np.array([0, 0, 0, 1, 0, 1, 1, 1])
+        from pymetis import part_graph
+        _, p = part_graph(num_parts,
+                          xadj=mesh.nodal_adjacency.neighbors_starts.tolist(),
+                          adjncy=mesh.nodal_adjacency.neighbors.tolist())
+        part_per_element = np.array(p)
 
         local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
     else:
         local_mesh = mesh_dist.receive_mesh_part()
 
-    vol_discr = DGDiscretizationWithBoundaries(cl_ctx, local_mesh, order=order,
+    vol_discr = DGDiscretizationWithBoundaries(cl_ctx, local_mesh, order=5,
             mpi_communicator=comm)
 
     sym_x = sym.nodes(local_mesh.dim)
@@ -87,6 +85,9 @@ def simple_mpi_communication_entrypoint(order):
             ) - (sym_all_faces_func - sym_bdry_faces_func)
             )
 
+    print(bound_face_swap)
+    # 1/0
+
     hopefully_zero = bound_face_swap(queue, myfunc=myfunc)
     import numpy.linalg as la
     error = la.norm(hopefully_zero.get())
@@ -227,8 +228,7 @@ def test_mpi(num_ranks):
 
 @pytest.mark.mpi
 @pytest.mark.parametrize("num_ranks", [2])
-@pytest.mark.parametrize("order", [2])
-def test_simple_mpi(num_ranks, order):
+def test_simple_mpi(num_ranks):
     pytest.importorskip("mpi4py")
 
     from subprocess import check_call
@@ -236,7 +236,6 @@ def test_simple_mpi(num_ranks, order):
     newenv = os.environ.copy()
     newenv["RUN_WITHIN_MPI"] = "1"
     newenv["TEST_SIMPLE_MPI_COMMUNICATION"] = "1"
-    newenv["order"] = str(order)
     check_call([
         "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI",
         sys.executable, __file__],
@@ -250,8 +249,7 @@ if __name__ == "__main__":
         if "TEST_MPI_COMMUNICATION" in os.environ:
             mpi_communication_entrypoint()
         elif "TEST_SIMPLE_MPI_COMMUNICATION" in os.environ:
-            order = int(os.environ["order"])
-            simple_mpi_communication_entrypoint(order)
+            simple_mpi_communication_entrypoint()
     else:
         import sys
         if len(sys.argv) > 1:
-- 
GitLab