diff --git a/pytools/prefork.py b/pytools/prefork.py index a81c987d67e0b556cd7479251be48b334567ffd6..9cb90541ccc9a0ba194eb73cec87fc5cbef1aeeb 100644 --- a/pytools/prefork.py +++ b/pytools/prefork.py @@ -8,51 +8,64 @@ Since none of this is MPI-specific, it got parked in pytools. from __future__ import absolute_import - - - class ExecError(OSError): pass +class DirectForker(object): + def __init__(self): + self.apids = {} + self.count = 0 - -class DirectForker: - @staticmethod - def call(cmdline, cwd=None): + def call(self, cmdline, cwd=None): from subprocess import call + try: return call(cmdline, cwd=cwd) except OSError as e: raise ExecError("error invoking '%s': %s" - % ( " ".join(cmdline), e)) + % (" ".join(cmdline), e)) + + def call_async(self, cmdline, cwd=None): + from subprocess import Popen - @staticmethod - def call_capture_stdout(cmdline, cwd=None): - from subprocess import Popen, PIPE try: - return Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE, stderr=PIPE).communicate()[0] + self.count += 1 + + proc = Popen(cmdline, cwd=cwd) + self.apids[self.count] = proc + + return self.count except OSError as e: - raise ExecError("error invoking '%s': %s" - % ( " ".join(cmdline), e)) + raise ExecError("error invoking '%s': %s" + % (" ".join(cmdline), e)) - @staticmethod - def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): - """ - :returns: a tuple (return code, stdout_data, stderr_data). - """ + def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True): from subprocess import Popen, PIPE + try: - popen = Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE, stderr=PIPE) + popen = Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE, + stderr=PIPE) stdout_data, stderr_data = popen.communicate() + if error_on_nonzero and popen.returncode: raise ExecError("status %d invoking '%s': %s" - % (popen.returncode, " ".join(cmdline), stderr_data)) + % (popen.returncode, " ".join(cmdline), + stderr_data)) + return popen.returncode, stdout_data, stderr_data except OSError as e: raise ExecError("error invoking '%s': %s" - % ( " ".join(cmdline), e)) + % (" ".join(cmdline), e)) + + def wait(self, aid): + proc = self.apids.pop(aid) + retc = proc.wait() + return retc + + def waitall(self): + return {aid: self.wait(aid) for aid in list(self.apids)} def _send_packet(sock, data): @@ -64,6 +77,7 @@ def _send_packet(sock, data): sock.sendall(pack("I", len(packet))) sock.sendall(packet) + def _recv_packet(sock, who="Process", partner="other end"): from struct import calcsize, unpack size_bytes_size = calcsize("I") @@ -85,36 +99,39 @@ def _recv_packet(sock, who="Process", partner="other end"): return loads(packet) - - def _fork_server(sock): + # Ignore keyboard interrupts, we'll get notified by the parent. import signal - # ignore keyboard interrupts, we'll get notified by the parent. signal.signal(signal.SIGINT, signal.SIG_IGN) - quitflag = [False] - - def quit(): - quitflag[0] = True + # Construct a local DirectForker to do the dirty work + df = DirectForker() funcs = { - "quit": quit, - "call": DirectForker.call, - "call_capture_stdout": DirectForker.call_capture_stdout, - "call_capture_output": DirectForker.call_capture_output, + "call": df.call, + "call_async": df.call_async, + "call_capture_output": df.call_capture_output, + "wait": df.wait, + "waitall": df.waitall } try: - while not quitflag[0]: - func_name, args, kwargs = _recv_packet(sock, - who="Prefork server", partner="parent") - - try: - result = funcs[func_name](*args, **kwargs) - except Exception as e: - _send_packet(sock, ("exception", e)) + while True: + func_name, args, kwargs = _recv_packet( + sock, who="Prefork server", partner="parent" + ) + + if func_name == "quit": + df.waitall() + _send_packet(sock, ("ok", None)) + break else: - _send_packet(sock, ("ok", result)) + try: + result = funcs[func_name](*args, **kwargs) + except Exception as e: + _send_packet(sock, ("exception", e)) + else: + _send_packet(sock, ("ok", result)) finally: sock.close() @@ -122,18 +139,19 @@ def _fork_server(sock): os._exit(0) - - - -class IndirectForker: +class IndirectForker(object): def __init__(self, server_pid, sock): self.server_pid = server_pid self.socket = sock + import atexit + atexit.register(self._quit) + def _remote_invoke(self, name, *args, **kwargs): _send_packet(self.socket, (name, args, kwargs)) - status, result = _recv_packet(self.socket, - who="Prefork client", partner="prefork server") + status, result = _recv_packet( + self.socket, who="Prefork client", partner="prefork server" + ) if status == "exception": raise result @@ -142,24 +160,31 @@ class IndirectForker: def _quit(self): self._remote_invoke("quit") + from os import waitpid waitpid(self.server_pid, 0) def call(self, cmdline, cwd=None): return self._remote_invoke("call", cmdline, cwd) - def call_capture_stdout(self, cmdline, cwd=None): - return self._remote_invoke("call_capture_stdout", cmdline, cwd) + def call_async(self, cmdline, cwd=None): + return self._remote_invoke("call_async", cmdline, cwd) def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True): - return self._remote_invoke("call_capture_output", cmdline, cwd, - error_on_nonzero) + return self._remote_invoke("call_capture_output", cmdline, cwd, + error_on_nonzero) + def wait(self, aid): + return self._remote_invoke("wait", aid) + def waitall(self): + return self._remote_invoke("waitall") def enable_prefork(): - if isinstance(forker[0], IndirectForker): + global forker + + if isinstance(forker, IndirectForker): return from socket import socketpair @@ -168,30 +193,34 @@ def enable_prefork(): from os import fork fork_res = fork() + # Child if fork_res == 0: - # child s_parent.close() _fork_server(s_child) + # Parent else: s_child.close() - forker[0] = IndirectForker(fork_res, s_parent) + forker = IndirectForker(fork_res, s_parent) - import atexit - atexit.register(forker[0]._quit) +forker = DirectForker() +def call(cmdline, cwd=None): + return forker.call(cmdline, cwd) -forker = [DirectForker()] -def call(cmdline, cwd=None): - return forker[0].call(cmdline, cwd) +def call_async(cmdline, cwd=None): + return forker.call_async(cmdline, cwd) -def call_capture_stdout(cmdline, cwd=None): - from warnings import warn - warn("call_capture_stdout is deprecated: use call_capture_output instead", - stacklevel=2) - return forker[0].call_capture_stdout(cmdline, cwd) def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): - return forker[0].call_capture_output(cmdline, cwd, error_on_nonzero) + return forker.call_capture_output(cmdline, cwd, error_on_nonzero) + + +def wait(aid): + return forker.wait(aid) + + +def waitall(): + return forker.waitall()