r"""
Utilities
"""
import errno
import os
import platform
import traceback
from typing import Optional
class RemoteException(Exception):
"""
Raised if an exception occurred in one of the child processes.
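
    EXAMPLES:

    A minimal illustration; the string representation is simply the stored
    traceback text::

        sage: from sage_docbuild.utils import RemoteException
        sage: str(RemoteException("remote traceback text"))
        'remote traceback text'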
"""
tb: str
def __init__(self, tb: str):
"""
Initialize the exception.
INPUT:
- ``tb`` -- the traceback of the exception.
"""
self.tb = tb
def __str__(self):
"""
Return a string representation of the exception.
"""
return self.tb


class RemoteExceptionWrapper:
    """
    Used by child processes to capture exceptions raised during execution
    and report them to the main process, including the correct traceback.
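
    EXAMPLES:

    A minimal sketch of the round trip: unpickling a wrapped exception
    reconstructs the original exception, with the remote traceback attached
    as its ``__cause__``::

        sage: import pickle
        sage: from sage_docbuild.utils import RemoteException, RemoteExceptionWrapper
        sage: unwrapped = pickle.loads(pickle.dumps(RemoteExceptionWrapper(ZeroDivisionError())))
        sage: isinstance(unwrapped, ZeroDivisionError)
        True
        sage: isinstance(unwrapped.__cause__, RemoteException)
        True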
"""
exc: BaseException
tb: str
def __init__(self, exc: BaseException):
"""
Initialize the exception wrapper.
INPUT:
- ``exc`` -- the exception to wrap.
"""
self.exc = exc
self.tb = traceback.format_exc()
tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
tb = "".join(tb)
self.exc = exc
self.tb = f'\n"""\n{tb}"""'

    @staticmethod
    def _rebuild_exc(exc: BaseException, tb: str):
        """
        Reconstruct the exception, attaching the remote traceback (wrapped
        in a :class:`RemoteException`) as its ``__cause__``.
        """
        exc.__cause__ = RemoteException(tb)
        return exc

    def __reduce__(self):
        """
        TESTS::

            sage: import pickle
            sage: from sage_docbuild.utils import RemoteExceptionWrapper
            sage: pickle.dumps(RemoteExceptionWrapper(ZeroDivisionError()), 0).decode()
            ...RemoteExceptionWrapper...ZeroDivisionError...
        """
return RemoteExceptionWrapper._rebuild_exc, (self.exc, self.tb)


class WorkerDiedException(RuntimeError):
    """
    Raised if a worker process dies unexpectedly.
    """

    original_exception: Optional[BaseException]

    def __init__(
        self, message: Optional[str], original_exception: Optional[BaseException] = None
    ):
        super().__init__(message)
        self.original_exception = original_exception


def build_many(target, args, processes=None):
    """
    Map a list of arguments in ``args`` to a single-argument target function
    ``target`` in parallel using ``multiprocessing.cpu_count()`` (or
    ``processes`` if given) simultaneous processes.

    This is a simplified version of ``multiprocessing.Pool.map`` from the
    Python standard library which avoids a couple of its pitfalls.  In
    particular, it can abort (with a :class:`RuntimeError`) without hanging
    if one of the worker processes unexpectedly dies.  It also has semantics
    equivalent to ``maxtasksperchild=1``; that is, one process is started
    per argument.  As such, this is inefficient for processing large numbers
    of fast tasks, but appropriate for running longer tasks (such as doc
    builds) which may also require significant cleanup.

    It also avoids starting new processes from a pthread, which results in
    at least one known issue:

    * When PARI is built with multi-threading support, forking a Sage
      process from a thread leaves the main Pari interface instance broken
      (see :issue:`26608#comment:38`).

    In the future this may be replaced by a generalized version of the more
    robust parallel processing implementation from ``sage.doctest.forker``.

    EXAMPLES::

        sage: from sage_docbuild.utils import build_many
        sage: def target(N):
        ....:     import time
        ....:     time.sleep(float(0.1))
        ....:     print('Processed task %s' % N)
        sage: _ = build_many(target, range(8), processes=8)
        Processed task ...
        Processed task ...
        Processed task ...
        Processed task ...
        Processed task ...
        Processed task ...
        Processed task ...
        Processed task ...

    This version can also return a result, and thus can be used as a
    replacement for ``multiprocessing.Pool.map`` (i.e. it still blocks
    until the result is ready)::

        sage: def square(N):
        ....:     return N * N
        sage: build_many(square, range(100))
        [0, 1, 4, 9, ..., 9604, 9801]
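
    Results are returned in the order of the input arguments, regardless of
    the order in which the individual tasks complete; a small sketch,
    reusing ``square`` from above::

        sage: build_many(square, range(8), processes=2)
        [0, 1, 4, 9, 16, 25, 36, 49]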

    If the target function raises an exception in any of the workers,
    ``build_many`` raises that exception and all other results are
    discarded.  Any in-progress tasks may still be allowed to complete
    gracefully before the exception is raised::

        sage: def target(N):
        ....:     import time, os, signal
        ....:     if N == 4:
        ....:         # Task 4 is a poison pill
        ....:         1 / 0
        ....:     else:
        ....:         time.sleep(float(0.5))
        ....:         print('Processed task %s' % N)

    Note: In practice this test might still show output from the other
    worker processes before the poison-pill is executed.  It may also
    display the traceback from the failing process on stderr.  However, due
    to how the doctest runner works, the doctest will only expect the final
    exception::

        sage: build_many(target, range(8), processes=8)
        Traceback (most recent call last):
        ...
            raise ZeroDivisionError("rational division by zero")
        ZeroDivisionError: rational division by zero
        ...
            raise worker_exc.original_exception
        ZeroDivisionError: rational division by zero

    Similarly, if one of the worker processes dies unexpectedly or otherwise
    exits non-zero (e.g. killed by a signal), any in-progress tasks will be
    completed gracefully, but then a :class:`RuntimeError` is raised and
    pending tasks are not started::

        sage: def target(N):
        ....:     import time, os, signal
        ....:     if N == 4:
        ....:         # Task 4 is a poison pill
        ....:         os.kill(os.getpid(), signal.SIGKILL)
        ....:     else:
        ....:         time.sleep(float(0.5))
        ....:         print('Processed task %s' % N)
        sage: build_many(target, range(8), processes=8)
        Traceback (most recent call last):
        ...
        WorkerDiedException: worker for 4 died with non-zero exit code -9
    """
    from multiprocessing import Process, Queue, cpu_count, set_start_method

    # On macOS, Python defaults to the 'spawn' start method, which does not
    # work for the doc build; force the 'fork' start method instead.
    if platform.system() == "Darwin":
        set_start_method("fork", force=True)

    from queue import Empty

    if processes is None:
        processes = cpu_count()

    workers = [None] * processes
    tasks = enumerate(args)
    results = []
    result_queue = Queue()

    # Utility functions used by the main loop below
    def run_worker(target, queue, idx, task):
        try:
            result = target(task)
        except BaseException as exc:
            # Ship the exception back to the main process, wrapped so that
            # the traceback survives pickling.
            queue.put((None, RemoteExceptionWrapper(exc)))
        else:
            queue.put((idx, result))

    def bring_out_yer_dead(w, task, exitcode):
        """
        Handle a dead / completed worker.  Raises ``WorkerDiedException``
        if it returned with a non-zero exit code.
        """
        if w is None or exitcode is None:
            # The worker is still alive (or was never started)
            return (w, task)

        # If this worker was reaped manually via os.wait(), record its exit
        # code so that ``Process.exitcode`` reports it correctly.
        if w._popen.returncode is None:
            w._popen.returncode = exitcode

        if exitcode != 0:
            raise WorkerDiedException(
                f"worker for {task[1]} died with non-zero exit code {w.exitcode}"
            )

        try:
            result = result_queue.get_nowait()
            if result[0] is None:
                # The worker reported an exception rather than a result;
                # unpickling already reconstructed the original exception.
                exception = result[1]
                raise WorkerDiedException("", original_exception=exception)
            else:
                results.append(result)
        except Empty:
            # The worker exited cleanly without reporting a result
            pass

        w.join()
        return None

    def wait_for_one():
        """Wait for a single process and return its pid and exit code."""
        try:
            pid, sts = os.wait()
        except OSError as exc:
            # ECHILD simply means there are no child processes left to wait
            # for; anything else is a genuine error.
            if exc.errno != errno.ECHILD:
                raise
            else:
                return None, None

        # Mirror multiprocessing's convention: a negative exit code means
        # the process was terminated by the corresponding signal.
        if os.WIFSIGNALED(sts):
            exitcode = -os.WTERMSIG(sts)
        else:
            exitcode = os.WEXITSTATUS(sts)
        return pid, exitcode

    def reap_workers(waited_pid=None, waited_exitcode=None):
        """
        This is the main worker handling loop.

        Checks if workers have completed their tasks and spawns new workers
        if there are more tasks on the queue.  Returns ``False`` if there is
        more work to be done or ``True`` if the work is complete.

        Raises a ``WorkerDiedException`` if a worker exits unexpectedly.
        """
        all_done = True

        for idx, w in enumerate(workers):
            if w is not None:
                w, task = w
                if w.pid == waited_pid:
                    # This worker was just reaped by os.wait(), so use the
                    # exit code obtained there.
                    exitcode = waited_exitcode
                else:
                    exitcode = w.exitcode

                w = bring_out_yer_dead(w, task, exitcode)

            if w is None:
                # The slot is free: start a worker for the next task, if any
                try:
                    task = next(tasks)
                except StopIteration:
                    pass
                else:
                    w = Process(target=run_worker,
                                args=(target, result_queue) + task)
                    w.start()
                    w = (w, task)

            workers[idx] = w

            if w is not None:
                all_done = False

        return all_done

    waited_pid = None
    waited_exitcode = None
    worker_exc = None

    try:
        while True:
            # Check the status of each worker slot; stop once all tasks
            # have been processed.
            try:
                if reap_workers(waited_pid, waited_exitcode):
                    break
            except WorkerDiedException as exc:
                worker_exc = exc
                break

            # Block until some child process exits, then feed its pid and
            # exit code into the next round of reap_workers().
            waited_pid, waited_exitcode = wait_for_one()
    finally:
        try:
            remaining_workers = [w for w in workers if w is not None]
            for w, _ in remaining_workers:
                # Give any remaining workers a chance to shut down
                try:
                    w.terminate()
                except OSError as exc:
                    # ESRCH means the process is already dead
                    if exc.errno != errno.ESRCH:
                        raise
            for w, _ in remaining_workers:
                w.join()
        finally:
            if worker_exc is not None:
                # Re-raise the exception from the worker, preferring the
                # original exception if one was captured.
                if worker_exc.original_exception is not None:
                    raise worker_exc.original_exception
                else:
                    raise worker_exc

    # Drain any results still sitting in the queue.
    while True:
        try:
            results.append(result_queue.get_nowait())
        except Empty:
            break

    # Return the results sorted back into the order of the input arguments.
    return [r[1] for r in sorted(results, key=lambda r: r[0])]