diff --git a/pytools/mpi.py b/pytools/mpi.py index 51ce7ff3c0bb203ae91e41745551a6c7db5cd07a..a575c0752d46d6322717be316bd5dbfc4e46b252 100644 --- a/pytools/mpi.py +++ b/pytools/mpi.py @@ -1,16 +1,28 @@ -def in_mpi_relaunch(): - import os - return "PYTOOLS_RUN_WITHIN_MPI" in os.environ +def check_for_mpi_relaunch(argv): + if argv[1] != "--mpi-relaunch": + return + + from pickle import loads + f, args, kwargs = loads(argv[2]) + + f(*args, **kwargs) + import sys + sys.exit() -def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs): - if in_mpi_relaunch(): - callable(*args, **kwargs) - else: - import sys, os - newenv = os.environ.copy() - newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1" - from subprocess import check_call - check_call(["mpirun", "-np", str(ranks), - sys.executable, py_script], env=newenv) +def run_with_mpi_ranks(py_script, ranks, callable, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + import sys + import os + newenv = os.environ.copy() + newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1" + + from pickle import dumps + callable_and_args = dumps((callable, args, kwargs)) + from subprocess import check_call + check_call(["mpirun", "-np", str(ranks), + sys.executable, py_script, "--mpi-relaunch", callable_and_args], + env=newenv)