From 53b124e3c40cbf3da6d20bb6f0ea7629ffa821a0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 12 Jan 2023 20:13:13 -0600 Subject: [PATCH] run_test_with_mpi: Allow passing extra_env_vars Co-authored-by: Matthew Smith --- test/test_distributed.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index 95f220b..232c408 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -32,23 +32,33 @@ import os # {{{ mpi test infrastructure -def run_test_with_mpi(num_ranks, f, *args): +def run_test_with_mpi(num_ranks, f, *args, extra_env_vars=None): import pytest pytest.importorskip("mpi4py") + if extra_env_vars is None: + extra_env_vars = {} + from pickle import dumps from base64 import b64encode - invocation_info = b64encode(dumps((f, args))).decode() from subprocess import check_call + env_vars = { + "RUN_WITHIN_MPI": "1", + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + # NOTE: CI uses OpenMPI; -x to pass env vars. MPICH uses -env check_call([ "mpiexec", "-np", str(num_ranks), - "-x", "RUN_WITHIN_MPI=1", "--oversubscribe", - "-x", f"INVOCATION_INFO={invocation_info}", - sys.executable, "-m", "mpi4py", __file__]) + ] + [ + item + for env_name, env_val in env_vars.items() + for item in ["-x", f"{env_name}={env_val}"] + ] + [sys.executable, "-m", "mpi4py", __file__]) def run_test_with_mpi_inner(): -- GitLab