diff --git a/test/test_meshmode.py b/test/test_meshmode.py index be66f28dab3ea82666e74c7475c4d675cb5c939f..6b6c7b2e1067931b55963672a327f9fa798a9324 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -57,8 +57,7 @@ def test_mpi_communication(num_partitions): import os newenv = os.environ.copy() newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1" - check_call(["mpirun", "-np", str(num_ranks), - sys.executable, "test/testmpi.py", str(num_partitions)], + check_call(["mpirun", "-np", str(num_ranks), sys.executable, "testmpi.py"], env=newenv) diff --git a/test/testmpi.py b/test/testmpi.py index 7f84e5015a61743afbd9e63b08a03d314dc5a7b2..511b1fe98030ee226bce74ce19274d546b704be0 100644 --- a/test/testmpi.py +++ b/test/testmpi.py @@ -2,11 +2,12 @@ from __future__ import division, absolute_import, print_function import numpy as np -def mpi_comm(num_parts): +def mpi_comm(): from mpi4py import MPI comm = MPI.COMM_WORLD rank = comm.Get_rank() + num_parts = comm.Get_size() - 1 # This rank only partitions a mesh and sends them to their respective ranks. if rank == 0: @@ -125,9 +126,4 @@ def mpi_comm(num_parts): if __name__ == "__main__": - import sys - - assert len(sys.argv) == 2, 'Invalid number of arguments' - - num_parts = int(sys.argv[1]) - mpi_comm(num_parts) + mpi_comm()