Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/tools/lib/python/unittest_helper.py
170891 views
1
#!/usr/bin/env python3
2
# SPDX-License-Identifier: GPL-2.0
3
# Copyright(c) 2025-2026: Mauro Carvalho Chehab <[email protected]>.
4
#
5
# pylint: disable=C0103,R0912,R0914,E1101
6
7
"""
8
Provides helper functions and classes to execute Python unit tests.
9
10
These helper functions provide a nice colored output summary of each
11
executed test and, when a test fails, it shows the difference in diff
12
format when running in verbose mode, like::
13
14
$ tools/unittests/nested_match.py -v
15
...
16
Traceback (most recent call last):
17
File "/new_devel/docs/tools/unittests/nested_match.py", line 69, in test_count_limit
18
self.assertEqual(replaced, "bar(a); bar(b); foo(c)")
19
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20
AssertionError: 'bar(a) foo(b); foo(c)' != 'bar(a); bar(b); foo(c)'
21
- bar(a) foo(b); foo(c)
22
? ^^^^
23
+ bar(a); bar(b); foo(c)
24
? ^^^^^
25
...
26
27
It also allows filtering what tests will be executed via ``-k`` parameter.
28
29
Typical usage is to do::
30
31
from unittest_helper import run_unittest
32
...
33
34
if __name__ == "__main__":
35
run_unittest(__file__)
36
37
If arguments need to be passed in a more complex scenario, it can be
38
used as in this example::
39
40
from unittest_helper import TestUnits, run_unittest
41
...
42
env = {'sudo': ""}
43
...
44
if __name__ == "__main__":
45
runner = TestUnits()
46
base_parser = runner.parse_args()
47
base_parser.add_argument('--sudo', action='store_true',
48
help='Enable tests requiring sudo privileges')
49
50
args = base_parser.parse_args()
51
52
# Update module-level flag
53
if args.sudo:
54
env['sudo'] = "1"
55
56
# Run tests with customized arguments
57
runner.run(__file__, parser=base_parser, args=args, env=env)
58
"""
59
60
import argparse
61
import atexit
62
import os
63
import re
64
import unittest
65
import sys
66
67
from unittest.mock import patch
68
69
70
class Summary(unittest.TestResult):
71
"""
72
Overrides ``unittest.TestResult`` class to provide a nice colored
73
summary. When in verbose mode, displays actual/expected difference in
74
unified diff format.
75
"""
76
def __init__(self, *args, **kwargs):
77
super().__init__(*args, **kwargs)
78
79
#: Dictionary to store organized test results.
80
self.test_results = {}
81
82
#: max length of the test names.
83
self.max_name_length = 0
84
85
def startTest(self, test):
86
super().startTest(test)
87
test_id = test.id()
88
parts = test_id.split(".")
89
90
# Extract module, class, and method names
91
if len(parts) >= 3:
92
module_name = parts[-3]
93
else:
94
module_name = ""
95
if len(parts) >= 2:
96
class_name = parts[-2]
97
else:
98
class_name = ""
99
100
method_name = parts[-1]
101
102
# Build the hierarchical structure
103
if module_name not in self.test_results:
104
self.test_results[module_name] = {}
105
106
if class_name not in self.test_results[module_name]:
107
self.test_results[module_name][class_name] = []
108
109
# Track maximum test name length for alignment
110
display_name = f"{method_name}:"
111
112
self.max_name_length = max(len(display_name), self.max_name_length)
113
114
def _record_test(self, test, status):
115
test_id = test.id()
116
parts = test_id.split(".")
117
if len(parts) >= 3:
118
module_name = parts[-3]
119
else:
120
module_name = ""
121
if len(parts) >= 2:
122
class_name = parts[-2]
123
else:
124
class_name = ""
125
method_name = parts[-1]
126
self.test_results[module_name][class_name].append((method_name, status))
127
128
def addSuccess(self, test):
129
super().addSuccess(test)
130
self._record_test(test, "OK")
131
132
def addFailure(self, test, err):
133
super().addFailure(test, err)
134
self._record_test(test, "FAIL")
135
136
def addError(self, test, err):
137
super().addError(test, err)
138
self._record_test(test, "ERROR")
139
140
def addSkip(self, test, reason):
141
super().addSkip(test, reason)
142
self._record_test(test, f"SKIP ({reason})")
143
144
def printResults(self, verbose):
145
"""
146
Print results using colors if tty.
147
"""
148
# Check for ANSI color support
149
use_color = sys.stdout.isatty()
150
COLORS = {
151
"OK": "\033[32m", # Green
152
"FAIL": "\033[31m", # Red
153
"SKIP": "\033[1;33m", # Yellow
154
"PARTIAL": "\033[33m", # Orange
155
"EXPECTED_FAIL": "\033[36m", # Cyan
156
"reset": "\033[0m", # Reset to default terminal color
157
}
158
if not use_color:
159
for c in COLORS:
160
COLORS[c] = ""
161
162
# Calculate maximum test name length
163
if not self.test_results:
164
return
165
try:
166
lengths = []
167
for module in self.test_results.values():
168
for tests in module.values():
169
for test_name, _ in tests:
170
lengths.append(len(test_name) + 1) # +1 for colon
171
max_length = max(lengths) + 2 # Additional padding
172
except ValueError:
173
sys.exit("Test list is empty")
174
175
# Print results
176
for module_name, classes in self.test_results.items():
177
if verbose:
178
print(f"{module_name}:")
179
for class_name, tests in classes.items():
180
if verbose:
181
print(f" {class_name}:")
182
for test_name, status in tests:
183
if not verbose and status in [ "OK", "EXPECTED_FAIL" ]:
184
continue
185
186
# Get base status without reason for SKIP
187
if status.startswith("SKIP"):
188
status_code = status.split()[0]
189
else:
190
status_code = status
191
color = COLORS.get(status_code, "")
192
print(
193
f" {test_name + ':':<{max_length}}{color}{status}{COLORS['reset']}"
194
)
195
if verbose:
196
print()
197
198
# Print summary
199
print(f"\nRan {self.testsRun} tests", end="")
200
if hasattr(self, "timeTaken"):
201
print(f" in {self.timeTaken:.3f}s", end="")
202
print()
203
204
if not self.wasSuccessful():
205
print(f"\n{COLORS['FAIL']}FAILED (", end="")
206
failures = getattr(self, "failures", [])
207
errors = getattr(self, "errors", [])
208
if failures:
209
print(f"failures={len(failures)}", end="")
210
if errors:
211
if failures:
212
print(", ", end="")
213
print(f"errors={len(errors)}", end="")
214
print(f"){COLORS['reset']}")
215
216
217
def flatten_suite(suite):
    """
    Return the test cases contained in ``suite`` as a flat list.

    Nested ``unittest.TestSuite`` objects are expanded depth-first, so the
    returned cases keep the suite's iteration order.
    """
    flat = []
    # Stack of iterators: the top one is the suite currently being walked.
    pending = [iter(suite)]
    while pending:
        try:
            entry = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue

        if isinstance(entry, unittest.TestSuite):
            pending.append(iter(entry))
        else:
            flat.append(entry)
    return flat
226
227
228
class TestUnits:
    """
    Helper class to set verbosity level.

    This class discovers test files, imports their unittest classes and
    executes the tests in them.
    """
    def parse_args(self):
        """Returns a parser for command line arguments."""
        parser = argparse.ArgumentParser(description="Test runner with regex filtering")
        parser.add_argument("-v", "--verbose", action="count", default=1)
        parser.add_argument("-q", "--quiet", action="store_true")
        parser.add_argument("-f", "--failfast", action="store_true")
        parser.add_argument("-k", "--keyword",
                            help="Regex pattern to filter test methods")
        return parser

    @staticmethod
    def _filter_tests(tests, keyword):
        """
        Return a ``unittest.TestSuite`` holding the tests from ``tests``
        whose method name (last component of ``test.id()``) matches the
        ``keyword`` regex with ``re.search`` semantics.

        Writes an error to stderr and exits with status 1 if ``keyword`` is
        not a valid regular expression.
        """
        try:
            regex = re.compile(keyword)
        except re.error as e:
            sys.stderr.write(f"Invalid regex pattern: {e}\n")
            sys.exit(1)

        filtered_suite = unittest.TestSuite()
        for test in tests:
            method_name = test.id().split(".")[-1]
            if regex.search(method_name):
                filtered_suite.addTest(test)
        return filtered_suite

    def run(self, caller_file=None, pattern=None,
            suite=None, parser=None, args=None, env=None):
        """
        Execute all tests from the unit test file.

        It contains several optional parameters:

        ``caller_file``:
            - name of the file that contains test.

              typical usage is to place __file__ at the caller test, e.g.::

                  if __name__ == "__main__":
                      TestUnits().run(__file__)

        ``pattern``:
            - optional pattern to match multiple file names. Defaults
              to basename of ``caller_file``. Files are searched in the
              directory that contains ``caller_file``. Ignored when
              ``suite`` is given.

        ``suite``:
            - an unittest suite initialized by the caller using
              ``unittest.TestLoader().discover()``.

        ``parser``:
            - an argparse parser. If not defined, this helper will create
              one.

        ``args``:
            - an ``argparse.Namespace`` data filled by the caller.

        ``env``:
            - environment variables that will be passed to the test suite.
              They are merged into ``os.environ`` (the change is undone at
              process exit). A ``VERBOSE`` entry holding the effective
              verbosity is always added; a caller-supplied dict is updated
              in place. Defaults to a copy of ``os.environ``.

        At least ``caller_file`` or ``suite`` must be used, otherwise a
        ``TypeError`` will be raised.

        This method does not return: it calls ``sys.exit()`` with status 0
        if all tests passed and 1 otherwise.
        """
        if not args:
            if not parser:
                parser = self.parse_args()
            args = parser.parse_args()

        if not caller_file and not suite:
            raise TypeError("Either caller_file or suite is needed at TestUnits")

        # Effective verbosity: --quiet overrides any number of -v flags
        verbose = 0 if args.quiet else args.verbose

        if not env:
            env = os.environ.copy()

        env["VERBOSE"] = f"{verbose}"

        patcher = patch.dict(os.environ, env)
        patcher.start()
        # ensure it gets stopped after
        atexit.register(patcher.stop)

        # Load ONLY tests from the calling file (or files matching pattern)
        if not suite:
            if not pattern:
                pattern = os.path.basename(caller_file)

            loader = unittest.TestLoader()
            suite = loader.discover(start_dir=os.path.dirname(caller_file),
                                    pattern=pattern)

        # Flatten the suite for environment injection
        tests_to_inject = flatten_suite(suite)

        # Filter tests by method name if -k specified
        if args.keyword:
            suite = self._filter_tests(tests_to_inject, args.keyword)
        else:
            suite = unittest.TestSuite(tests_to_inject)

        # With -v the stock TextTestResult prints per-test lines and diffs;
        # otherwise Summary prints a condensed, colored report.
        resultclass = None if verbose >= 2 else Summary

        runner = unittest.TextTestRunner(verbosity=verbose,
                                         resultclass=resultclass,
                                         failfast=args.failfast)
        result = runner.run(suite)
        if resultclass:
            result.printResults(verbose)

        sys.exit(not result.wasSuccessful())
350
351
352
def run_unittest(fname):
    """
    Run the tests defined in ``fname`` with the default command-line options.

    This is the simplest entry point: it creates a ``TestUnits`` runner and
    hands it the test file, so no extra arguments can be forwarded to the
    tests. Place it at the end of each unittest module::

        if __name__ == "__main__":
            run_unittest(__file__)
    """
    runner = TestUnits()
    runner.run(fname)
364
365