diff --git a/pytools/prefork.py b/pytools/prefork.py index 93a56d7e03aad6eea44c8a9046e3beecb51c9cad..9cb90541ccc9a0ba194eb73cec87fc5cbef1aeeb 100644 --- a/pytools/prefork.py +++ b/pytools/prefork.py @@ -13,18 +13,36 @@ class ExecError(OSError): class DirectForker(object): - @staticmethod - def call(cmdline, cwd=None): + def __init__(self): + self.apids = {} + self.count = 0 + + def call(self, cmdline, cwd=None): from subprocess import call + try: return call(cmdline, cwd=cwd) except OSError as e: raise ExecError("error invoking '%s': %s" % (" ".join(cmdline), e)) - @staticmethod - def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): + def call_async(self, cmdline, cwd=None): + from subprocess import Popen + + try: + self.count += 1 + + proc = Popen(cmdline, cwd=cwd) + self.apids[self.count] = proc + + return self.count + except OSError as e: + raise ExecError("error invoking '%s': %s" + % (" ".join(cmdline), e)) + + def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True): from subprocess import Popen, PIPE + try: popen = Popen(cmdline, cwd=cwd, stdin=PIPE, stdout=PIPE, stderr=PIPE) @@ -40,6 +58,15 @@ class DirectForker(object): raise ExecError("error invoking '%s': %s" % (" ".join(cmdline), e)) + def wait(self, aid): + proc = self.apids.pop(aid) + retc = proc.wait() + + return retc + + def waitall(self): + return {aid: self.wait(aid) for aid in list(self.apids)} + def _send_packet(sock, data): from struct import pack @@ -73,33 +100,38 @@ def _recv_packet(sock, who="Process", partner="other end"): def _fork_server(sock): + # Ignore keyboard interrupts, we'll get notified by the parent. import signal - # ignore keyboard interrupts, we'll get notified by the parent. signal.signal(signal.SIGINT, signal.SIG_IGN) - quitflag = [False] - - def quit(): - quitflag[0] = True + # Construct a local DirectForker to do the dirty work + df = DirectForker() funcs = { - "quit": quit, - "call": DirectForker.call, - "call_capture_output": DirectForker.call_capture_output, + "call": df.call, + "call_async": df.call_async, + "call_capture_output": df.call_capture_output, + "wait": df.wait, + "waitall": df.waitall } try: - while not quitflag[0]: + while True: func_name, args, kwargs = _recv_packet( sock, who="Prefork server", partner="parent" ) - try: - result = funcs[func_name](*args, **kwargs) - except Exception as e: - _send_packet(sock, ("exception", e)) + if func_name == "quit": + df.waitall() + _send_packet(sock, ("ok", None)) + break else: - _send_packet(sock, ("ok", result)) + try: + result = funcs[func_name](*args, **kwargs) + except Exception as e: + _send_packet(sock, ("exception", e)) + else: + _send_packet(sock, ("ok", result)) finally: sock.close() @@ -112,6 +144,9 @@ class IndirectForker(object): self.server_pid = server_pid self.socket = sock + import atexit + atexit.register(self._quit) + def _remote_invoke(self, name, *args, **kwargs): _send_packet(self.socket, (name, args, kwargs)) status, result = _recv_packet( @@ -132,10 +167,19 @@ class IndirectForker(object): def call(self, cmdline, cwd=None): return self._remote_invoke("call", cmdline, cwd) + def call_async(self, cmdline, cwd=None): + return self._remote_invoke("call_async", cmdline, cwd) + def call_capture_output(self, cmdline, cwd=None, error_on_nonzero=True): return self._remote_invoke("call_capture_output", cmdline, cwd, error_on_nonzero) + def wait(self, aid): + return self._remote_invoke("wait", aid) + + def waitall(self): + return self._remote_invoke("waitall") + def enable_prefork(): global forker @@ -158,9 +202,6 @@ def enable_prefork(): s_child.close() forker = IndirectForker(fork_res, s_parent) - import atexit - atexit.register(forker._quit) - forker = DirectForker() @@ -169,5 +210,17 @@ def call(cmdline, cwd=None): return forker.call(cmdline, cwd) +def call_async(cmdline, cwd=None): + return forker.call_async(cmdline, cwd) + + def call_capture_output(cmdline, cwd=None, error_on_nonzero=True): return forker.call_capture_output(cmdline, cwd, error_on_nonzero) + + +def wait(aid): + return forker.wait(aid) + + +def waitall(): + return forker.waitall()