Commit 76ef26c9 authored by Andreas Klöckner

Add pytools.mpiwrap. Migrate from boostmpi to mpi4py.

parent 6275791b
@@ -7,3 +7,4 @@ build
 MANIFEST
 dist
 setuptools*egg
+setuptools.pth
@@ -232,13 +232,13 @@ class LogManager(object):
     def __init__(self, filename=None, mode="r", mpi_comm=None, capture_warnings=True):
         """Initialize this log manager instance.
-        @arg filename: If given, the filename to which this log is bound.
+        :param filename: If given, the filename to which this log is bound.
           If this database exists, the current state is loaded from it.
-        @arg mode: One of "w", "r" for write, read. "w" assumes that the
+        :param mode: One of "w", "r" for write, read. "w" assumes that the
           database is initially empty.
-        @arg mpi_comm: An C{boostmpi} communicator. If given, logs are periodically
+        :arg mpi_comm: An C{mpi4py} communicator. If given, logs are periodically
           synchronized to the head node, which then writes them out to disk.
-        @arg capture_warnings: Tap the Python warnings facility and save warnings
+        :param capture_warnings: Tap the Python warnings facility and save warnings
           to the log file.
         """
@@ -306,15 +306,13 @@ class LogManager(object):
             # set globally unique run_id
             if self.is_parallel:
-                from boostmpi import broadcast
                 self.set_constant("unique_run_id",
-                        broadcast(self.mpi_comm, _get_unique_id(),
-                            root=self.head_rank))
+                        self.mpi_comm.bcast(_get_unique_id(), root=self.head_rank))
             else:
                 self.set_constant("unique_run_id", _get_unique_id())

             if self.is_parallel:
-                self.set_constant("rank_count", self.mpi_comm.size)
+                self.set_constant("rank_count", self.mpi_comm.Get_size())
             else:
                 self.set_constant("rank_count", 1)
         else:
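
Note on the migration pattern in the hunk above: boostmpi exposed module-level broadcast()/gather() functions and a size attribute, whereas mpi4py puts these on the communicator object (bcast(), gather(), Get_size()). The following standalone sketch is not part of the commit; the run-id value and script name are made up for illustration.

    # Sketch of the boostmpi -> mpi4py calls used above; run under e.g.
    # "mpirun -np 2 python demo.py" (file name is illustrative).
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    head_rank = 0

    # boostmpi: broadcast(comm, obj, root=head_rank)
    # mpi4py:   comm.bcast(obj, root=head_rank)
    if comm.Get_rank() == head_rank:
        run_id = "run-illustrative-0001"
    else:
        run_id = None
    run_id = comm.bcast(run_id, root=head_rank)

    # boostmpi: comm.size (attribute); the commit switches to Get_size()
    rank_count = comm.Get_size()

    print(comm.Get_rank(), run_id, rank_count)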
@@ -780,9 +778,7 @@ class LogManager(object):
                 for qname in self.quantity_data.iterkeys())

         if self.mpi_comm is not None and self.have_nonlocal_watches:
-            from boostmpi import broadcast, gather
-
-            gathered_data = gather(self.mpi_comm, data_block, self.head_rank)
+            gathered_data = self.mpi_comm.gather(data_block, self.head_rank)
         else:
             gathered_data = [data_block]
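
For the watch-data hunk above, the relevant collective is gather: mpi4py's comm.gather() returns a list with one entry per rank on the root and None everywhere else, which matches how gathered_data is used. A hedged sketch, not from the commit; the data_block contents are invented:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    head_rank = 0

    data_block = {"rank": comm.Get_rank()}   # stand-in for the per-rank watch data

    # boostmpi: gather(comm, data_block, head_rank)
    # mpi4py:   comm.gather(data_block, root=head_rank)
    gathered_data = comm.gather(data_block, root=head_rank)

    if comm.Get_rank() == head_rank:
        print(gathered_data)        # list with one entry per rank
    else:
        assert gathered_data is None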
@@ -808,7 +804,7 @@ class LogManager(object):
         self.next_watch_tick = self.tick_count + int(max(1, ticks_per_sec))

         if self.mpi_comm is not None and self.have_nonlocal_watches:
-            self.next_watch_tick = broadcast(self.mpi_comm,
+            self.next_watch_tick = self.mpi_comm.bcast(
                     self.next_watch_tick, self.head_rank)
......
 def in_mpi_relaunch():
     import os
-    return "BOOSTMPI_RUN_WITHIN_MPI" in os.environ
+    return "PYTOOLS_RUN_WITHIN_MPI" in os.environ

 def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs):
     if in_mpi_relaunch():
@@ -8,7 +8,7 @@ def run_with_mpi_ranks(py_script, ranks, callable, *args, **kwargs):
     else:
         import sys, os
         newenv = os.environ.copy()
-        newenv["BOOSTMPI_RUN_WITHIN_MPI"] = "1"
+        newenv["PYTOOLS_RUN_WITHIN_MPI"] = "1"

         from subprocess import check_call
         check_call(["mpirun", "-np", str(ranks),
......
+import mpi4py.rc
+mpi4py.rc.initialize = False
+
+import pytools.prefork
+pytools.prefork.enable_prefork()
+
+from mpi4py.MPI import *
+
+if Is_initialized():
+    raise RuntimeError("MPI already initialized before MPI wrapper import")
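
Because the new wrapper sets mpi4py.rc.initialize = False before importing mpi4py.MPI, importing it does not start MPI; the caller is expected to initialize and finalize explicitly. A hedged usage sketch follows; only pytools.mpiwrap itself comes from this commit, the rest is illustrative.

    import pytools.mpiwrap as mpi

    mpi.Init()          # explicit, since automatic initialization was disabled
    try:
        comm = mpi.COMM_WORLD
        print("rank", comm.Get_rank(), "of", comm.Get_size())
    finally:
        mpi.Finalize()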