diff --git a/pytools/mpi.py b/pytools/mpi.py
index d67106f12f663c33ffdae60c9d6aaa29c5ac37f5..c20c3cc17ba8b4f4d4c4aee23bb26cfc44e2a95e 100644
--- a/pytools/mpi.py
+++ b/pytools/mpi.py
@@ -1,9 +1,12 @@
-def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs):
+def in_mpi_relaunch():
     import os
-    if "BOOSTMPI_RUN_WITHIN_MPI" in os.environ:
+    return "BOOSTMPI_RUN_WITHIN_MPI" in os.environ
+
+def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs):
+    if in_mpi_relaunch():
         callable(*args, **kwargs)
     else:
-        import sys
+        import sys, os
         newenv = os.environ.copy()
         newenv["BOOSTMPI_RUN_WITHIN_MPI"] = "1"