diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 05def1d257fc5f2afd21ea9a2fcba033c3aac7f6..db14dd13a81e1cb21079dab76ed0498b316f5c4d 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -36,7 +36,7 @@ from grudge import sym, bind, Discretization from grudge.shortcuts import set_up_rk4 -def boundary_communication_entrypoint(): +def simple_communication_entrypoint(): cl_ctx = cl.create_some_context() queue = cl.CommandQueue(cl_ctx) from meshmode.distributed import MPIMeshDistributor @@ -208,19 +208,31 @@ def mpi_communication_entrypoint(): # {{{ MPI test pytest entrypoint @pytest.mark.mpi -@pytest.mark.parametrize("testcase", [ - # "MPI_COMMUNICATION", - "BOUNDARY_COMMUNICATION" - ]) @pytest.mark.parametrize("num_ranks", [2]) -def test_mpi(testcase, num_ranks): +def test_mpi(num_ranks): pytest.importorskip("mpi4py") from subprocess import check_call import sys newenv = os.environ.copy() newenv["RUN_WITHIN_MPI"] = "1" - newenv[testcase] = "1" + newenv["TEST_MPI_COMMUNICATION"] = "1" + check_call([ + "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI", + sys.executable, __file__], + env=newenv) + + +@pytest.mark.mpi +@pytest.mark.parametrize("num_ranks", [2]) +def test_simple_mpi(num_ranks): + pytest.importorskip("mpi4py") + + from subprocess import check_call + import sys + newenv = os.environ.copy() + newenv["RUN_WITHIN_MPI"] = "1" + newenv["TEST_SIMPLE_COMMUNICATION"] = "1" check_call([ "mpiexec", "-np", str(num_ranks), "-x", "RUN_WITHIN_MPI", sys.executable, __file__], @@ -230,10 +242,11 @@ def test_mpi(testcase, num_ranks): if __name__ == "__main__": - if "MPI_COMMUNICATION" in os.environ: - mpi_communication_entrypoint() - elif "BOUNDARY_COMMUNICATION" in os.environ: - boundary_communication_entrypoint() + if "RUN_WITHIN_MPI" in os.environ: + if "TEST_MPI_COMMUNICATION" in os.environ: + mpi_communication_entrypoint() + elif "TEST_SIMPLE_COMMUNICATION" in os.environ: + simple_communication_entrypoint() else: import sys if len(sys.argv) > 1: