Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
aos
GitHub Repository: aos/firecracker
Path: blob/main/tests/framework/mpsing.py
1956 views
1
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
# SPDX-License-Identifier: Apache-2.0
3
"""A multi-process singleton implementation.
4
5
This module provides a facility for synchronization across a multi-process
6
pytest session, in the form of the `MultiprocessSingleton` class.
7
8
A `MultiprocessSingleton` object achieves cross-worker synchronization by
9
executing code in a single "server" process, under a lock. The process
10
where the singleton is initialized becomes the server process. Subsequently,
11
when called in the context of any child (i.e. "client") process, methods
12
marked with `@ipcmethod` are sent via an IPC pipe to the server process, for
13
execution. The result is returned to the caller via the same pipe.
14
15
`@ipcmethod` invocations are serialized via a lock, such that at any one time,
16
only one worker is executing code on its corresponding singleton server.
17
18
Restrictions:
19
- the singleton server must be initialized before any workers are
20
`fork()`ed;
21
- `@ipcmethod` arguments and results must be picklable, since they are
22
transmitted via an IPC pipe;
23
- the server process must poll the singleton for incoming execution
24
requests, and call its `handle_ipc_call()` method to handle them.
25
The singleton provides a pollable file descriptor via its `fileno()`
26
method.
27
"""
28
29
from multiprocessing import Pipe, Lock
30
31
32
def ipcmethod(fn):
    """Mark a singleton method to be executed in the server context.

    A multi-process singleton implementor should use this decorator to mark
    methods that should be executed in the server context, under the
    singleton lock.

    Args:
        fn: the method implementation that the server should run.

    Returns:
        A proxy function that forwards the call, by method name, to the
        instance's `_ipc_call()`. The undecorated function stays reachable
        via the proxy's `orig_fn` attribute, and the proxy keeps the
        wrapped function's name and docstring.
    """
    # Imported here so the decorator is self-contained and does not rely on
    # the module-level import block.
    # pylint: disable=import-outside-toplevel
    from functools import wraps

    @wraps(fn)
    def proxy_fn(inst, *args, **kwargs):
        # pylint: disable=protected-access
        return inst._ipc_call(fn.__name__, *args, **kwargs)

    # The server side looks this attribute up to run the real implementation.
    proxy_fn.orig_fn = fn
    return proxy_fn
44
45
46
class SingletonReinitError(Exception):
    """Raised when a singleton that already has an instance is built again."""
48
49
50
class MultiprocessSingleton:
    """A multi-process singleton (duh).

    Each instance owns a lock and a two-ended IPC pipe. Client-side calls go
    through `_ipc_call()`, which writes a request to the client end of the
    pipe; the server reads requests from the server end in
    `handle_ipc_call()` and writes the result back on the same pipe.
    """

    # Local instance of this singleton, created lazily by `instance()`.
    _instance = None

    def __init__(self):
        """Create the lock and the IPC pipe used by this singleton.

        Raises:
            SingletonReinitError: if an instance has already been stored on
                the class by `instance()`.
        """
        if self.__class__._instance is not None:
            raise SingletonReinitError()
        self._mpsing_lock = Lock()
        self._mpsing_server_conn, self._mpsing_client_conn = Pipe()

    @classmethod
    def instance(cls):
        """Return the local instance of this singleton.

        The instance is created on first use and cached on the class.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _ipc_call(self, fn_name, *args, **kwargs):
        """Perform the IPC call, from the client context.

        This method is called in the client context. It sends an RPC request
        to the server, and returns its result.

        Args:
            fn_name: name of the method the server should execute.
            *args: positional arguments forwarded to that method.
            **kwargs: keyword arguments forwarded to that method.

        Returns:
            Whatever the server sent back, unless it is an exception.

        Raises:
            AttributeError: if this object has no attribute `fn_name`.
            TypeError: if the attribute `fn_name` is not callable.
            BaseException: any exception instance received from the server
                is re-raised here.
        """
        if not callable(getattr(self, fn_name)):
            raise TypeError(f"{fn_name} is not callable")
        # Hold the lock across the whole request/response exchange, so that
        # replies cannot be interleaved between concurrent callers.
        with self._mpsing_lock:
            msg = (fn_name, args, kwargs)
            self._mpsing_client_conn.send(msg)
            result = self._mpsing_client_conn.recv()
        if isinstance(result, BaseException):
            # TODO: sending the exception through the IPC pipe will strip its
            #       __traceback__ property, as traceback objects cannot be
            #       pickled. It would be nice to send some kind of traceback
            #       info back though.
            raise result
        return result

    def fileno(self):
        """Return a pollable IPC file descriptor.

        The returned FD should be used to determine whether the server needs
        to service any pending requests (i.e. when data is ready to be read
        from the FD).
        """
        return self._mpsing_server_conn.fileno()

    def handle_ipc_call(self):
        """Handle the next IPC call from a client.

        Called only in the server context, this method will perform a blocking
        read from the IPC pipe. If the caller wants to avoid blocking here,
        they should poll/select `self.fileno()` for reading before calling
        this method.

        Any exception raised by the requested method is captured and sent
        back to the client in place of a result. If the result (or the
        captured exception) cannot be serialized for the pipe, a `TypeError`
        describing the failure is sent instead, so that the client always
        receives a reply.

        Raises:
            OSError: if writing the reply to the pipe fails.
        """
        (fn_name, args, kwargs) = self._mpsing_server_conn.recv()
        try:
            res = getattr(self, fn_name).orig_fn(self, *args, **kwargs)
        # pylint: disable=broad-except
        except BaseException as exc:
            res = exc
        try:
            self._mpsing_server_conn.send(res)
        except OSError:
            # The pipe itself failed; there is no way to reach the client.
            raise
        # pylint: disable=broad-except
        except Exception as exc:
            # The reply could not be serialized. The client is blocked in
            # `recv()` while holding the singleton lock, so leaving it
            # without an answer would stall every worker. Send back an
            # error that is known to be picklable instead.
            # NOTE(review): this assumes a failed serialization leaves
            # nothing written on the pipe - confirm against the
            # `multiprocessing.connection` implementation in use.
            self._mpsing_server_conn.send(
                TypeError(
                    f"result of {fn_name} could not be sent over IPC: {exc!r}"
                )
            )
113
114