From 05fd17e15e9e4c74f02c86b40394abb346b28cd3 Mon Sep 17 00:00:00 2001
From: Ellis <eshoag@illinois.edu>
Date: Sat, 21 Oct 2017 18:29:25 -0500
Subject: [PATCH] new tests for mpi communication

---
 grudge/symbolic/operators.py   |   6 +-
 test/test_mpi_communication.py | 104 +++++++++++++++++++++++++++++++++
 2 files changed, 107 insertions(+), 3 deletions(-)
 create mode 100644 test/test_mpi_communication.py

diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py
index 05a23ad..a1d0f21 100644
--- a/grudge/symbolic/operators.py
+++ b/grudge/symbolic/operators.py
@@ -387,14 +387,14 @@ class OppositeRankFaceSwap(Operator):
         if dd_in is None:
             # FIXME: What is FRESTR_INTERIOR_FACES?
             dd_in = sym.DOFDesc(sym.FRESTR_INTERIOR_FACES)
-            # dd_in = sym.DOFDesc(sym.BTAG_PARTITION)
+            # dd_in = sym.DOFDesc(BTAG_PARTITION)
         if dd_out is None:
             dd_out = dd_in
 
         # if dd_in.domain_tag is not BTAG_PARTITION:
         #     raise ValueError("dd_in must be a rank boundary faces domain")
-        # if dd_out != dd_in:
-        #     raise ValueError("dd_out and dd_in must be identical")
+        if dd_out != dd_in:
+            raise ValueError("dd_out and dd_in must be identical")
 
         super(OppositeRankFaceSwap, self).__init__(dd_in, dd_out)
 
diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py
new file mode 100644
index 0000000..a343beb
--- /dev/null
+++ b/test/test_mpi_communication.py
@@ -0,0 +1,104 @@
+from __future__ import division, absolute_import, print_function
+
+__copyright__ = """
+Copyright (C) 2017 Ellis Hoag
+Copyright (C) 2017 Andreas Kloeckner
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import pytest
+import os
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+
+
+def mpi_communication_entrypoint():
+    from meshmode.distributed import MPIMeshDistributor, MPIBoundaryCommunicator
+
+    from mpi4py import MPI
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    num_parts = comm.Get_size()
+
+    mesh_dist = MPIMeshDistributor(comm)
+
+    if mesh_dist.is_mananger_rank():
+        np.random.seed(42)
+        from meshmode.mesh.generation import generate_warped_rect_mesh
+        meshes = [generate_warped_rect_mesh(3, order=4, n=4) for _ in range(2)]
+
+        from meshmode.mesh.processing import merge_disjoint_meshes
+        mesh = merge_disjoint_meshes(meshes)
+
+        part_per_element = np.random.randint(num_parts, size=mesh.nelements)
+
+        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+    else:
+        local_mesh = mesh_dist.receive_mesh_part()
+
+    from meshmode.discretization.poly_element\
+                    import PolynomialWarpAndBlendGroupFactory
+    group_factory = PolynomialWarpAndBlendGroupFactory(4)
+    import pyopencl as cl
+    cl_ctx = cl.create_some_context()
+    queue = cl.CommandQueue(cl_ctx)
+
+    from meshmode.discretization import Discretization
+    vol_discr = Discretization(cl_ctx, local_mesh, group_factory)
+
+    logger.debug("Rank %d exiting", rank)
+
+
+# {{{ MPI test pytest entrypoint
+
+@pytest.mark.mpi
+@pytest.mark.parametrize("num_partitions", [3, 4])
+def test_mpi_communication(num_partitions):
+    pytest.importorskip("mpi4py")
+
+    num_ranks = num_partitions
+    from subprocess import check_call
+    import sys
+    newenv = os.environ.copy()
+    newenv["RUN_WITHIN_MPI"] = "1"
+    check_call([
+        "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI",
+        sys.executable, __file__],
+        env=newenv)
+
+# }}}
+
+if __name__ == "__main__":
+    if "RUN_WITHIN_MPI" in os.environ:
+        mpi_communication_entrypoint()
+    else:
+        import sys
+        if len(sys.argv) > 1:
+            exec(sys.argv[1])
+        else:
+            from py.test.cmdline import main
+            main([__file__])
+
+# vim: fdm=marker
-- 
GitLab