From e00ff50d12e6111d757e628e3d4eda6bee6d1cd4 Mon Sep 17 00:00:00 2001 From: ellis Date: Tue, 5 Sep 2017 17:07:57 -0500 Subject: [PATCH] Fix testmpi error --- test/test_meshmode.py | 3 +-- test/testmpi.py | 10 +++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/test/test_meshmode.py b/test/test_meshmode.py index be66f28d..6b6c7b2e 100644 --- a/test/test_meshmode.py +++ b/test/test_meshmode.py @@ -57,8 +57,7 @@ def test_mpi_communication(num_partitions): import os newenv = os.environ.copy() newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1" - check_call(["mpirun", "-np", str(num_ranks), - sys.executable, "test/testmpi.py", str(num_partitions)], + check_call(["mpirun", "-np", str(num_ranks), sys.executable, "testmpi.py"], env=newenv) diff --git a/test/testmpi.py b/test/testmpi.py index 7f84e501..511b1fe9 100644 --- a/test/testmpi.py +++ b/test/testmpi.py @@ -2,11 +2,12 @@ from __future__ import division, absolute_import, print_function import numpy as np -def mpi_comm(num_parts): +def mpi_comm(): from mpi4py import MPI comm = MPI.COMM_WORLD rank = comm.Get_rank() + num_parts = comm.Get_size() - 1 # This rank only partitions a mesh and sends them to their respective ranks. if rank == 0: @@ -125,9 +126,4 @@ def mpi_comm(num_parts): if __name__ == "__main__": - import sys - - assert len(sys.argv) == 2, 'Invalid number of arguments' - - num_parts = int(sys.argv[1]) - mpi_comm(num_parts) + mpi_comm() -- GitLab