diff --git a/test/test_distributed.py b/test/test_distributed.py index 95f220b7e6ad715f933f8e55d2fbd8cd0bcf150d..232c408a07d81ca9b8ea3b7c76b2ed6d8bead96d 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -32,23 +32,33 @@ import os # {{{ mpi test infrastructure -def run_test_with_mpi(num_ranks, f, *args): +def run_test_with_mpi(num_ranks, f, *args, extra_env_vars=None): import pytest pytest.importorskip("mpi4py") + if extra_env_vars is None: + extra_env_vars = {} + from pickle import dumps from base64 import b64encode - invocation_info = b64encode(dumps((f, args))).decode() from subprocess import check_call + env_vars = { + "RUN_WITHIN_MPI": "1", + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + # NOTE: CI uses OpenMPI; -x to pass env vars. MPICH uses -env check_call([ "mpiexec", "-np", str(num_ranks), - "-x", "RUN_WITHIN_MPI=1", "--oversubscribe", - "-x", f"INVOCATION_INFO={invocation_info}", - sys.executable, "-m", "mpi4py", __file__]) + ] + [ + item + for env_name, env_val in env_vars.items() + for item in ["-x", f"{env_name}={env_val}"] + ] + [sys.executable, "-m", "mpi4py", __file__]) def run_test_with_mpi_inner():