GitHub Repository: aos/firecracker
Path: blob/main/tests/framework/scheduler.py
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pytest plugin that schedules tests to run concurrently.

This plugin adds a new command line option (`--concurrency`), allowing the
user to choose the maximum number of worker processes that can run tests
concurrently.

Tests are split into batches, each batch being assigned a maximum concurrency
level. For instance, all performance tests will run sequentially
(i.e. concurrency=1), since they rely on the availability of the full host
resources, in order to make accurate measurements. Additionally, other tests
may be restricted to running sequentially, if they are inherently
concurrency-unsafe. See `PytestScheduler.pytest_runtestloop()`.

Scheduling is achieved by overriding the pytest run loop (i.e.
`pytest_runtestloop()`), and splitting the test session item list across
multiple `fork()`ed worker processes. Since no user code is run before
`pytest_runtestloop()`, each worker becomes a pytest session itself.
Reporting is disabled for worker processes; each worker sends its results
back to the main / server process, via an IPC pipe, for aggregation.
"""

import multiprocessing as mp
import os
import re
import sys
from random import random
from select import select
from time import sleep
import pytest
# Needed to force deselect nonci tests.
from _pytest.mark import Expression, MarkMatcher
from _pytest.main import ExitCode

from . import defs  # pylint: disable=relative-beyond-top-level
from . import mpsing  # pylint: disable=relative-beyond-top-level
from . import report as treport  # pylint: disable=relative-beyond-top-level


class PytestScheduler(mpsing.MultiprocessSingleton):
    """A pretty custom test execution scheduler."""

    def __init__(self):
        """Initialize the scheduler.

        Not to be called directly, since this is a singleton. Use
        `PytestScheduler.instance()` to get the scheduler object.
        """
        super().__init__()
        self._mp_singletons = [self]
        self.session = None

        # Initialize a report for this session.
        self._report = treport.Report(defs.TEST_RESULTS_DIR / "report")

    def register_mp_singleton(self, mp_singleton):
        """Register a multi-process singleton object.

        Since the scheduler will be handling the main testing loop, it needs
        to be aware of any multi-process singletons that must be serviced
        during the test run (i.e. polled and allowed to handle method
        execution in the server context).
        """
        self._mp_singletons.append(mp_singleton)

    @staticmethod
    def do_pytest_addoption(parser):
        """Pytest hook. Add the concurrency command line option."""
        avail_cpus = len(os.sched_getaffinity(0))
        # Defaulting to a third of the available (logical) CPUs sounds like a
        # good enough plan.
        default = max(1, int(avail_cpus / 3))
        parser.addoption(
            "--concurrency",
            dest="concurrency",
            action="store",
            type=int,
            default=default,
            help="Concurrency level (max number of worker processes to spawn)."
        )
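
        # For example, `pytest --concurrency=4` caps each concurrency-safe
        # batch at four worker processes. With 12 available CPUs, the default
        # computed above is likewise max(1, int(12 / 3)) = 4.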

    def pytest_sessionstart(self, session):
        """Pytest hook. Called at pytest session start.

        This will execute in the server context (before the tests are
        executed).
        """
        self.session = session

    def pytest_runtest_logreport(self, report):
        """Pytest hook. Called whenever a new test report is ready.

        This will execute in the worker / child context.
        """
        self._add_report(report)

        # Mark the test item as finished after its `call` phase completes.
        if report.when == "call":
            self._report.finish_test_item(report)

    def pytest_report_collectionfinish(self, items, *_):
        """Pytest hook. Called after collecting all the tests."""
        self._report.add_collected_items(items)

    @staticmethod
    def filter_batch(config, batch, marker_name):
        """Deselect marked tests which are not explicitly selected."""
        deselected = []
        expr = Expression.compile(config.option.markexpr)
        for item in batch['items'][:]:
            for key in item.keywords:
                if key == marker_name and \
                        not expr.evaluate(MarkMatcher.from_item(item)):
                    deselected.append(item)
                    batch['items'].remove(item)
                    break

        config.hook.pytest_deselected(items=deselected)

    def pytest_pyfunc_call(self, pyfuncitem):
        """Pytest hook. Called before executing a test."""
        # Wrap the test function with a custom one that captures its return
        # value; the captured expected vs. actual values are used in
        # reporting.
        self._report.catch_return(pyfuncitem)

    def pytest_runtestloop(self, session):
        """Pytest hook. The main test scheduling and running loop.

        Called in the server process context.
        """
        # Don't run tests on test discovery.
        if session.config.option.collectonly:
            return True

        # max_concurrency = self.session.config.option.concurrency
        schedule = [
            {
                # Performance batch: tests that measure performance, and need
                # to be run in a non-concurrent environment.
                'name': 'performance',
                'concurrency': 1,
                'patterns': [
                    "/performance/.+",
                ],
                'items': []
            },
            {
                # Unsafe batch: tests that, for any reason, are not
                # concurrency-safe, and therefore need to be run sequentially.
                'name': 'unsafe',
                'concurrency': 1,
                'patterns': [
                    "/functional/test_initrd.py",
                    "/functional/test_max_vcpus.py",
                    "/functional/test_rate_limiter.py",
                    "/functional/test_signals.py",
                    "/build/test_coverage.py"
                ],
                'items': []
            },
            {
                # Safe batch: tests that can be run safely in a concurrent
                # environment.
                'name': 'safe',
                # FIXME: we still have some framework concurrency issues
                # which prevent us from successfully using `max_concurrency`.
                # 'concurrency': max_concurrency,
                'concurrency': 1,
                'patterns': [
                    "/functional/.+",
                    "/build/.+",
                    "/security/.+"
                ],
                'items': []
            },
            {
                # Unknown batch: a catch-all batch, scheduling any tests that
                # haven't been categorized to run sequentially (since we don't
                # know if they are concurrency-safe).
                'name': 'unknown',
                'concurrency': 1,
                'patterns': [".+"],
                'items': []
            }
        ]

        # Go through the list of tests and assign each of them to its
        # corresponding batch in the schedule.
        for item in session.items:
            # A test can match any of the patterns defined by the batch,
            # in order to get assigned to it.
            for batch in schedule:
                # Found a matching batch. No need to look any further.
                if re.search(
                        "|".join(["({})".format(x) for x in batch['patterns']]),
                        "/".join(item.listnames()),
                ) is not None:
                    batch['items'].append(item)
                    break
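
        # For instance, a collected item whose "/"-joined name looks like
        # "tests/integration_tests/performance/test_boottime.py/test_boottime"
        # (path shown for illustration) matches the "/performance/.+"
        # pattern, so it lands in the sequential performance batch.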

        # Filter out empty batches.
        schedule = [batch for batch in schedule if batch['items']]

        # Evaluate the marker expression only for the marked batch items.
        # If pytest runs with a marker expression which does not select
        # `nonci` marked tests (e.g. `-m "not nonci"`, or no marker
        # expression at all), the tests marked with the `nonci` marker will
        # be deselected.
        for batch in schedule:
            PytestScheduler.filter_batch(session.config,
                                         batch,
                                         marker_name="nonci")

        for batch in schedule:
            self._raw_stdout(
                "\n[ ",
                self._colorize('yellow', batch['name']),
                " | ",
                "{} tests".format(len(batch['items'])),
                " | ",
                "{} worker(s)".format(batch['concurrency']),
                " ]\n"
            )
            self._run_batch(batch)

        # Return a truthy value, so pytest's default (sequential) test
        # running loop does not also run.
        return "stahp"

    @pytest.mark.tryfirst
    # pylint: disable=unused-argument
    # pylint: disable=no-self-use
    def pytest_sessionfinish(self, session, exitstatus):
        """Pytest hook. Wrap up the whole testing session.

        Since the scheduler is more or less mangling the test session in order
        to distribute test items to worker processes, the main pytest process
        can become unaware of test failures and errors. This session wrap-up
        hook is used to set the correct exit code.
        """
        trep = session.config.pluginmanager.getplugin("terminalreporter")
        if "error" in trep.stats:
            session.exitstatus = ExitCode.INTERNAL_ERROR
        if "failed" in trep.stats:
            session.exitstatus = ExitCode.TESTS_FAILED

    def _run_batch(self, batch):
        """Run the tests in this batch, spread across multiple workers.

        Called in the server process context.
        """
        max_workers = batch['concurrency']
        items_per_worker = max(1, int(len(batch['items']) / max_workers))
        workers = []
        while batch['items']:
            # Pop `items_per_worker` out from this batch and send them to
            # a new worker.
            worker_items = batch['items'][-items_per_worker:]
            del batch['items'][-items_per_worker:]

            # Avoid storming the host with too many workers started at once.
            _delay = random() + len(workers) / 5.0 if max_workers > 1 else 0

            # Create the worker process and start it up.
            worker = mp.Process(
                target=self._worker_main,
                args=(worker_items, _delay)
            )
            workers.append(worker)
            worker.start()

        # Main loop: reap finished workers and service IPC requests. The
        # multi-process singletons can be `select()`ed because each exposes
        # the `fileno()` of its IPC channel.
        while workers:
            rlist, _, _ = select(self._mp_singletons, [], [], 0.1)
            for mps in rlist:
                mps.handle_ipc_call()
            _ = [w.join() for w in workers if not w.is_alive()]
            workers = [w for w in workers if w.is_alive()]

    def _worker_main(self, items, startup_delay=0):
        """Execute a bunch of test items sequentially.

        This is the worker process entry point and main loop.
        """
        sys.stdin.close()
        # Sleeping here to avoid storming the host when many workers are
        # started at the same time.
        #
        # TODO: investigate the storming issue;
        # Not sure what the exact problem is, but worker storms cause an
        # elevated response time on the API socket. Since the response
        # time is measured by our decorators, it also includes the
        # Python libraries' overhead, which might be non-negligible.
        sleep(startup_delay)

        # Restrict the session to this worker's item list only.
        # I.e. make pytest believe that the test session is limited to this
        # worker's job.
        self.session.items = items

        # Disable the terminal reporting plugin, so it doesn't mess up
        # stdout, when run in a multi-process context.
        # The terminal reporter plugin will remain enabled in the server
        # process, gathering data via worker calls to `_add_report()`.
        trep = self.session.config.pluginmanager.get_plugin("terminalreporter")
        self.session.config.pluginmanager.unregister(trep)

        # Run each item, passing along the next item, which pytest needs in
        # order to decide how much teardown to perform between tests.
        for item, nextitem in zip(
                self.session.items,
                self.session.items[1:] + [None]
        ):
            item.ihook.pytest_runtest_protocol(item=item, nextitem=nextitem)

    @mpsing.ipcmethod
    def _add_report(self, report):
        """Send a test report to the server process.

        A report is generated for every test item, and for every test phase
        (setup, call, and teardown).
        """
        # Translation matrix from (when)x(outcome) to pytest's
        # terminalreporter plugin stats (dictionary) key.
        key_xlat = {
            "setup.passed": "",
            "setup.failed": "error",
            "setup.skipped": "skipped",
            "call.passed": "passed",
            "call.failed": "failed",
            "call.skipped": "skipped",
            "teardown.passed": "",
            "teardown.failed": "error",
            "teardown.skipped": ""
        }
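        # For example, a report with when="setup" and outcome="failed" is
        # counted under the "error" stat key, matching the terminalreporter
        # plugin's convention for setup failures.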
        stats_key = key_xlat["{}.{}".format(report.when, report.outcome)]

        trep = self.session.config.pluginmanager.get_plugin("terminalreporter")
        if trep:
            if stats_key not in trep.stats:
                trep.stats[stats_key] = []
            trep.stats[stats_key].append(report)

        if stats_key:
            self._report_progress(report.nodeid, stats_key)

    def _report_progress(self, nodeid, outcome):
        """Show the user some nice progress indication."""
        outcome_cols = {
            "passed": "green",
            "failed": "red",
            "error": "red",
            "skipped": "yellow"
        }
        if outcome not in outcome_cols:
            return

        color = outcome_cols[outcome]
        self._raw_stdout(
            " ",
            self._colorize(color, "{:10}".format(outcome.upper())),
            self._colorize(color, nodeid)
            if outcome in ["error", "failed"]
            else nodeid,
            "\n"
        )

    @staticmethod
    def _colorize(color, msg):
        """Add an ANSI / terminal color escape code to `msg`.

        If stdout is not a terminal, `msg` will just be encoded into a byte
        stream, without adding any ANSI decorations.
        Note: the returned value will always be a stream of bytes, not a
              string, since the result needs to be sent straight to the
              terminal.
        """
        if not isinstance(msg, bytes):
            msg = str(msg).encode("utf-8")
        if not sys.stdout.isatty():
            return msg
        term_codes = {
            'red': b"\x1b[31m",
            'yellow': b"\x1b[33m",
            'green': b"\x1b[32m",
            'reset': b"\x1b(B\x1b[m"
        }
        return term_codes[color] + msg + term_codes['reset']

    @staticmethod
    def _raw_stdout(*words):
        """Send raw-byte output to stdout.

        All arguments are concatenated and, if necessary, encoded into raw
        byte streams, before being written to stdout.
        """
        byte_words = [
            w if isinstance(w, bytes) else str(w).encode("utf-8")
            for w in words
        ]
        buf = b"".join(byte_words)
        os.write(sys.stdout.fileno(), buf)