From 76ef26c93e187ce98fbca38b71048678c6e9a4fc Mon Sep 17 00:00:00 2001
From: Andreas Kloeckner <inform@tiker.net>
Date: Thu, 5 Nov 2009 12:48:59 -0500
Subject: [PATCH] Add pytools.mpiwrap. Migrate from boostmpi to mpi4py.

---
 .gitignore         |  1 +
 pytools/log.py     | 20 ++++++++------------
 pytools/mpi.py     |  4 ++--
 pytools/mpiwrap.py | 11 +++++++++++
 4 files changed, 22 insertions(+), 14 deletions(-)
 create mode 100644 pytools/mpiwrap.py

diff --git a/.gitignore b/.gitignore
index 890dd85..a4fd4de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ build
 MANIFEST
 dist
 setuptools*egg
+setuptools.pth
diff --git a/pytools/log.py b/pytools/log.py
index a02b926..b8b7e9e 100644
--- a/pytools/log.py
+++ b/pytools/log.py
@@ -232,13 +232,13 @@ class LogManager(object):
     def __init__(self, filename=None, mode="r", mpi_comm=None, capture_warnings=True):
         """Initialize this log manager instance.
 
-        @arg filename: If given, the filename to which this log is bound.
+        :param filename: If given, the filename to which this log is bound.
           If this database exists, the current state is loaded from it.
-        @arg mode: One of "w", "r" for write, read. "w" assumes that the
+        :param mode: One of "w", "r" for write, read. "w" assumes that the
           database is initially empty.
-        @arg mpi_comm: An C{boostmpi} communicator. If given, logs are periodically
+        :arg mpi_comm: An C{mpi4py} communicator. If given, logs are periodically
           synchronized to the head node, which then writes them out to disk.
-        @arg capture_warnings: Tap the Python warnings facility and save warnings
+        :param capture_warnings: Tap the Python warnings facility and save warnings
           to the log file.
         """
 
@@ -306,15 +306,13 @@ class LogManager(object):
 
             # set globally unique run_id
             if self.is_parallel:
-                from boostmpi import broadcast
                 self.set_constant("unique_run_id",
-                        broadcast(self.mpi_comm, _get_unique_id(),
-                            root=self.head_rank))
+                        self.mpi_comm.bcast(_get_unique_id(), root=self.head_rank))
             else:
                 self.set_constant("unique_run_id", _get_unique_id())
 
             if self.is_parallel:
-                self.set_constant("rank_count", self.mpi_comm.size)
+                self.set_constant("rank_count", self.mpi_comm.Get_size())
             else:
                 self.set_constant("rank_count", 1)
         else:
@@ -780,9 +778,7 @@ class LogManager(object):
                 for qname in self.quantity_data.iterkeys())
 
         if self.mpi_comm is not None and self.have_nonlocal_watches:
-            from boostmpi import broadcast, gather
-
-            gathered_data = gather(self.mpi_comm, data_block, self.head_rank)
+            gathered_data = self.mpi_comm.gather(data_block, self.head_rank)
         else:
             gathered_data = [data_block]
 
@@ -808,7 +804,7 @@ class LogManager(object):
         self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec))
 
         if self.mpi_comm is not None and self.have_nonlocal_watches:
-            self.next_watch_tick = broadcast(self.mpi_comm,
+            self.next_watch_tick = self.mpi_comm.bcast(
                     self.next_watch_tick, self.head_rank)
 
 
diff --git a/pytools/mpi.py b/pytools/mpi.py
index c20c3cc..51ce7ff 100644
--- a/pytools/mpi.py
+++ b/pytools/mpi.py
@@ -1,6 +1,6 @@
 def in_mpi_relaunch():
     import os
-    return "BOOSTMPI_RUN_WITHIN_MPI" in os.environ
+    return "PYTOOLS_RUN_WITHIN_MPI" in os.environ
 
 def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs):
     if in_mpi_relaunch():
@@ -8,7 +8,7 @@ def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs):
     else:
         import sys, os
         newenv = os.environ.copy()
-        newenv["BOOSTMPI_RUN_WITHIN_MPI"] = "1"
+        newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1"
 
         from subprocess import check_call
         check_call(["mpirun", "-np", str(ranks), 
diff --git a/pytools/mpiwrap.py b/pytools/mpiwrap.py
new file mode 100644
index 0000000..f6dfcda
--- /dev/null
+++ b/pytools/mpiwrap.py
@@ -0,0 +1,11 @@
+import mpi4py.rc
+
+mpi4py.rc.initialize = False
+
+import pytools.prefork
+pytools.prefork.enable_prefork()
+
+from mpi4py.MPI import *
+
+if Is_initialized():
+    raise RuntimeError("MPI already initialized before MPI wrapper import")
-- 
GitLab