CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/general.py
Views: 475
1
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
"""
3
General utils
4
"""
5
6
import contextlib
7
import glob
8
import logging
9
import math
10
import os
11
import platform
12
import random
13
import re
14
import shutil
15
import signal
16
import time
17
import urllib
18
from itertools import repeat
19
from multiprocessing.pool import ThreadPool
20
from pathlib import Path
21
from subprocess import check_output
22
from zipfile import ZipFile
23
24
import cv2
25
import numpy as np
26
import pandas as pd
27
import pkg_resources as pkg
28
import torch
29
import torchvision
30
import yaml
31
32
from utils.downloads import gsutil_getsize
33
from utils.metrics import box_iou, fitness
34
# NOTE(review): pi is truncated to six decimals here (math is already imported and math.pi is more precise);
# confirm whether other modules rely on this exact truncated value before changing it.
pi = 3.141592
35
from utils.nms_rotated import obb_nms
36
37
# Settings
FILE = Path(__file__).resolve()  # absolute path of this module
ROOT = FILE.parents[1]  # YOLOv5 root directory (parent of the directory holding this file)
# os.cpu_count() may return None when the CPU count cannot be determined; fall back to 1 so the
# arithmetic below cannot raise TypeError. Result is always in the range 1..8.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
41
42
torch.set_printoptions(linewidth=320, precision=5, profile='long')
43
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
44
pd.options.display.max_columns = 10
45
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
46
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
47
48
49
def set_logging(name=None, verbose=True):
    """Reset the root logging configuration and return a logger.

    Args:
        name: Name passed to ``logging.getLogger``; ``None`` yields the root logger.
        verbose: If True and the ``RANK`` environment variable is -1 (unset) or 0, the root
            level is INFO; in every other case it is WARNING.

    Returns:
        logging.Logger: The logger registered under ``name``.
    """
    # Iterate over a copy: removing handlers while walking the live list skips every other handler,
    # and any handler left behind turns the basicConfig() call below into a silent no-op.
    for h in list(logging.root.handlers):
        logging.root.removeHandler(h)
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
    logging.basicConfig(format="%(message)s", level=level)
    return logging.getLogger(name)
56
57
58
# Module-level logger, created once at import time by calling set_logging() with this module's name
LOGGER = set_logging(__name__)  # define globally (used in train.py, val.py, detect.py, etc.)
59
60
61
class Profile(contextlib.ContextDecorator):
    """Wall-clock timer. Usage: @Profile() decorator or 'with Profile():' context manager.

    The elapsed time between entry and exit is printed to stdout with five decimals.
    """

    def __enter__(self):
        # Nothing is returned, so 'with Profile() as p' binds p to None
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        elapsed = time.time() - self.start
        print(f'Profile results: {elapsed:.5f}s')
68
69
70
class Timeout(contextlib.ContextDecorator):
    """Abort a block after a number of seconds using SIGALRM.

    Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager.

    Args:
        seconds: Countdown length; converted with ``int()``.
        timeout_msg: Message carried by the raised ``TimeoutError``.
        suppress_timeout_errors: If True, a ``TimeoutError`` leaving the block is swallowed.

    On platforms without ``signal.SIGALRM`` the context manager does nothing, so the
    wrapped code simply runs without a time limit.
    """

    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)
        self._previous_handler = None  # SIGALRM handler that was active before __enter__

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        if not hasattr(signal, 'SIGALRM'):
            return
        # signal.signal() returns the handler it replaced; keep it so __exit__ can put it back
        self._previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        if hasattr(signal, 'SIGALRM'):
            signal.alarm(0)  # cancel SIGALRM if it is still scheduled
            # The previous handler is None when it was not installed from Python; it cannot be restored then
            if self._previous_handler is not None:
                signal.signal(signal.SIGALRM, self._previous_handler)
                self._previous_handler = None
        if self.suppress and exc_type is TimeoutError:  # suppress TimeoutError
            return True
88
89
90
class WorkingDirectory(contextlib.ContextDecorator):
    """Temporarily change the process working directory.

    Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.
    The directory to return to is captured when the object is constructed, not on entry.
    """

    def __init__(self, new_dir):
        self.cwd = Path.cwd().resolve()  # directory restored on exit
        self.dir = new_dir  # directory entered on __enter__

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)
101
102
103
def try_except(func):
    """Decorator that prints any exception raised by ``func`` instead of propagating it.

    Usage: @try_except decorator.

    Returns:
        A wrapper that returns ``func``'s result on success and ``None`` when an exception
        was caught and printed.
    """
    import functools  # local import: keeps this block self-contained

    @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped function
    def handler(*args, **kwargs):
        try:
            return func(*args, **kwargs)  # forward the result instead of discarding it
        except Exception as e:
            print(e)

    return handler
112
113
114
def methods(instance):
    """Return the names of all callable attributes of ``instance`` that do not start with '__'."""
    found = []
    for attr in dir(instance):
        if callable(getattr(instance, attr)) and not attr.startswith("__"):
            found.append(attr)
    return found
117
118
119
def print_args(name, opt):
    """Log the attributes of an argparse namespace as '<name>: k1=v1, k2=v2, ...' at INFO level."""
    header = colorstr(f'{name}: ')
    pairs = [f'{key}={value}' for key, value in vars(opt).items()]
    LOGGER.info(header + ', '.join(pairs))
122
123
124
def init_seeds(seed=0):
    """Seed the Python, NumPy and PyTorch random number generators.

    See https://pytorch.org/docs/stable/notes/randomness.html. With seed 0, cuDNN is switched to
    deterministic mode with benchmarking off (slower, more reproducible); any other seed selects
    benchmarking on and deterministic mode off (faster, less reproducible).
    """
    import torch.backends.cudnn as cudnn

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if seed == 0:
        cudnn.benchmark, cudnn.deterministic = False, True
    else:
        cudnn.benchmark, cudnn.deterministic = True, False
132
133
134
def intersect_dicts(da, db, exclude=()):
    """Return the entries of ``da`` whose key is also in ``db`` with an equal ``.shape``.

    Keys containing any of the substrings in ``exclude`` are dropped. Values come from ``da``.
    """
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(pattern in key for pattern in exclude):
            continue
        if value.shape == db[key].shape:
            matched[key] = value
    return matched
137
138
139
def get_latest_run(search_dir='.'):
    """Return the path of the most recently changed 'last*.pt' under ``search_dir`` (i.e. to --resume from).

    The search is recursive and ranks candidates by ``os.path.getctime``. An empty string is
    returned when no file matches.
    """
    candidates = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    if not candidates:
        return ''
    return max(candidates, key=os.path.getctime)
143
144
145
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    """Return the user configuration directory, creating it if required.

    Args:
        dir: Sub-directory name appended to the OS-specific configuration directory.
        env_var: Environment variable that, when set and non-empty, is used verbatim as the
            configuration directory (``dir`` is not appended in that case).

    Returns:
        pathlib.Path: The existing configuration directory.
    """
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    # parents=True: a nested path supplied through the environment variable may have missing ancestors
    path.mkdir(parents=True, exist_ok=True)
    return path
156
157
158
def is_writeable(dir, test=False):
    """Return True if ``dir`` is writeable.

    Args:
        dir: Directory to check.
        test: If True, verify by actually creating and deleting 'tmp.txt' inside ``dir``;
            otherwise ask the OS via ``os.access``.

    Returns:
        bool: True when the directory can be written to.
    """
    if test:  # method 1
        file = Path(dir) / 'tmp.txt'
        try:
            with open(file, 'w'):  # open file with write permissions
                pass
            file.unlink()  # remove file
            return True
        except OSError:
            return False
    else:  # method 2
        # W_OK, not R_OK: the question is whether we may write; a read-only directory must report False
        return os.access(dir, os.W_OK)  # possible issues on Windows
171
172
173
def is_docker():
    """Return True if the environment looks like a Docker container, i.e. '/workspace' exists."""
    marker = Path('/workspace')  # or Path('/.dockerenv')
    return marker.exists()
176
177
178
def is_colab():
    """Return True if the ``google.colab`` package can be imported (Google Colab instance)."""
    try:
        import google.colab
    except ImportError:
        return False
    else:
        return True
185
186
187
def is_pip():
    """Return True if this file lives under a 'site-packages' directory (i.e. in a pip package)."""
    location = Path(__file__).resolve()
    return 'site-packages' in location.parts
190
191
192
def is_ascii(s=''):
    """Return True if ``str(s)`` consists only of ASCII characters.

    Any object is accepted (list, tuple, None, ...) because it is converted with ``str()`` first.
    Note: ``str.isascii()`` only exists from Python 3.7.
    """
    text = str(s)
    ascii_only = text.encode().decode('ascii', 'ignore')  # non-ASCII bytes are dropped here
    return len(ascii_only) == len(text)
196
197
198
def is_chinese(s='人工智能'):
    """Return True if ``s`` contains at least one character in the range U+4E00..U+9FFF.

    ``s`` is converted with ``str()`` first, so non-string inputs do not raise. The result is a
    real bool rather than a regex match object; truthiness is the same as before.
    """
    return bool(re.search('[\u4e00-\u9fff]', str(s)))
201
202
203
def emojis(str=''):
    """Return a platform-dependent emoji-safe version of the string.

    On Windows every non-ASCII character is removed; elsewhere the string is returned unchanged.
    """
    if platform.system() == 'Windows':
        return str.encode().decode('ascii', 'ignore')
    return str
206
207
208
def file_size(path):
    """Return the size of a file or directory in MB (bytes / 1E6).

    A directory is summed recursively over all regular files. Anything else yields 0.0.
    """
    target = Path(path)
    if target.is_file():
        return target.stat().st_size / 1E6
    if target.is_dir():
        total_bytes = sum(entry.stat().st_size for entry in target.glob('**/*') if entry.is_file())
        return total_bytes / 1E6
    return 0.0
217
218
219
def check_online():
    """Return True if a TCP connection to 1.1.1.1:443 can be opened within 5 seconds."""
    import socket
    try:
        # The connected socket is only a probe; close it immediately so the descriptor is not leaked
        with socket.create_connection(("1.1.1.1", 443), 5):  # check host accessibility
            pass
        return True
    except OSError:
        return False
227
228
229
@try_except
@WorkingDirectory(ROOT)
def check_git_status():
    """Print whether the local git checkout is behind origin/master and recommend 'git pull' if so.

    Runs inside ROOT. The assertions are deliberate: together with @try_except they turn an
    unmet precondition (not a git repository, Docker image, offline) into a printed
    'skipping check' message instead of an error.
    """
    msg = ', for updates see https://github.com/ultralytics/yolov5'
    print(colorstr('github: '), end='')
    assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
    assert not is_docker(), 'skipping check (Docker image)' + msg
    assert check_online(), 'skipping check (offline)' + msg

    cmd = 'git fetch && git config --get remote.origin.url'
    url = check_output(cmd, shell=True, timeout=5).decode().strip()  # git fetch
    # Remove a literal '.git' suffix only. str.rstrip('.git') strips any trailing run of the characters
    # '.', 'g', 'i', 't' and therefore mangles repository names that end in those letters.
    if url.endswith('.git'):
        url = url[:-len('.git')]
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
    if n > 0:
        s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
    else:
        s = f'up to date with {url} ✅'
    print(emojis(s))  # emoji-safe
248
249
250
def check_python(minimum='3.6.2'):
    """Assert that the running Python version is at least ``minimum`` (hard check via check_version)."""
    running = platform.python_version()
    check_version(running, minimum, name='Python ', hard=True)
253
254
255
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    """Compare a version against a required version.

    Args:
        current: Installed version string.
        minimum: Required version string.
        name: Label used as prefix in the message.
        pinned: If True the versions must be equal; otherwise ``current >= minimum`` is enough.
        hard: If True an unmet requirement raises AssertionError.
        verbose: If True an unmet requirement is logged as a warning.

    Returns:
        bool: Whether the requirement is met.
    """
    current = pkg.parse_version(current)
    minimum = pkg.parse_version(minimum)
    if pinned:
        result = current == minimum
    else:
        result = current >= minimum
    s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, s  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result
265
266
267
@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
    """Check that installed dependencies meet the requirements, optionally installing missing ones.

    Args:
        requirements: Path to a requirements file (str or Path), or a list/tuple of requirement strings.
        exclude: Package names (file input) or requirement strings (list input) to skip.
        install: If True, try 'pip install' for every unmet requirement; otherwise only report it.
    """
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    req_file = None  # stays None when requirements were passed as a list or tuple
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        req_file = Path(requirements)
        assert req_file.exists(), f"{prefix} {req_file.resolve()} not found, check failed."
        with req_file.open() as handle:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(handle) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    updated = 0  # number of packages updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if not install:
                print(f'{s}. Please install and rerun your command.')
                continue
            print(f"{s}, attempting auto-update...")
            try:
                assert check_online(), f"'pip install {r}' skipped (offline)"
                print(check_output(f"pip install '{r}'", shell=True).decode())
                updated += 1
            except Exception as e:
                print(f'{prefix} {e}')

    if updated:  # if packages updated
        source = req_file.resolve() if req_file is not None else requirements
        s = f"{prefix} {updated} package{'s' * (updated > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))
302
303
304
def check_img_size(imgsz, s=32, floor=0):
    """Round an image size up to a multiple of stride ``s`` in each dimension, never below ``floor``.

    Accepts an int (i.e. img_size=640) or an iterable of ints (i.e. img_size=[640, 480]); an int
    yields an int, anything else yields a list. A warning is printed when the size was changed.
    """
    is_scalar = isinstance(imgsz, int)
    dims = [imgsz] if is_scalar else imgsz
    rounded = [max(make_divisible(dim, int(s)), floor) for dim in dims]
    new_size = rounded[0] if is_scalar else rounded
    if new_size != imgsz:
        print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
313
314
315
def check_imshow():
    """Return True if the environment supports image display via cv2.imshow().

    Docker and Google Colab environments are rejected up front; otherwise a 1x1 test window is
    opened and closed. Any failure is printed as a warning and reported as False.
    """
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
    return True
328
329
330
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    """Assert that every given file name has an acceptable suffix.

    Args:
        file: A file name or a list/tuple of file names. Names without a suffix are not checked.
        suffix: One suffix string or a collection of acceptable suffixes.
        msg: Prefix for the assertion message.

    Raises:
        AssertionError: If a file's lower-cased suffix is not in ``suffix``.
    """
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = [suffix]
    names = file if isinstance(file, (list, tuple)) else [file]
    for f in names:
        s = Path(f).suffix.lower()  # file suffix
        if len(s):
            assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
339
340
341
def check_yaml(file, suffix=('.yaml', '.yml')):
    """Search/download a YAML file (if necessary) and return its path, checking the suffix."""
    yaml_path = check_file(file, suffix)
    return yaml_path
344
345
346
def check_file(file, suffix=''):
    """Locate a file, downloading or searching for it if necessary, and return its path as str.

    Resolution order: an existing file (or the empty string) is returned as is; an http(s) URL is
    downloaded into the current directory unless a file of that name already exists; anything else
    is searched recursively under ROOT/data, ROOT/models and ROOT/utils and must match exactly once.
    """
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file

    if file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).is_file():
            print(f'Found {url} locally at {file}')  # file already exists
        else:
            print(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file

    # search
    files = []
    for d in 'data', 'models', 'utils':  # search directories
        files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
    assert len(files), f'File not found: {file}'  # assert file was found
    assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
    return files[0]  # return file
369
370
371
def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
    #
    # data: a *.zip path/URL, a path to a dataset YAML file, or an already loaded dictionary
    # autodownload: if True, a missing 'val' path triggers the dataset's 'download' entry
    # Returns the dataset dictionary, with 'train'/'val'/'test' prefixed by the dataset path and
    # 'names' filled in when missing. Raises if 'val' is missing on disk and cannot be downloaded.

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))  # first YAML found in the extracted dir
        extract_dir, autodownload = data.parent, False  # already downloaded, so never autodownload again below

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    # Only 'val' and 'download' (s) are used below; train and test are unpacked but not read again here
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    # SECURITY NOTE(review): the 'download' string from the YAML is passed to the shell unmodified;
                    # only use dataset YAML files from trusted sources.
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # SECURITY NOTE(review): the 'download' string from the YAML is executed as Python code;
                    # only use dataset YAML files from trusted sources.
                    r = exec(s, {'yaml': data})  # return None
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary
421
422
423
def url2file(url):
    """Convert a URL to its file name, i.e. https://url.com/file.txt?auth -> file.txt.

    The query string is removed before the last path component is taken, so a '/' inside the
    query (e.g. '?token=a/b') no longer ends up as the file name. This matches the order used
    by check_file().
    """
    from urllib.parse import unquote  # explicit submodule import; 'import urllib' alone does not guarantee it

    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    file = Path(unquote(url).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
    return file
428
429
430
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    #
    # url: one URL/path (str or Path) or an iterable of them
    # dir: destination directory, created if missing
    # unzip: extract '.zip' and '.gz' files after download
    # delete: remove the archive after extracting (only applies when unzip is done)
    # curl: download with the curl command line tool instead of torch.hub
    # threads: number of worker threads; values above 1 use a ThreadPool
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        # NOTE(review): url is zipped directly here, so it is assumed to be an iterable of URLs;
        # a single str would be iterated character by character. The imap result is never consumed,
        # so exceptions raised in workers are presumably not surfaced — confirm.
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
462
463
464
def make_divisible(x, divisor):
    """Round ``x`` up to the next multiple of ``divisor``.

    A tensor divisor is reduced to its maximum element, converted to int, before use.
    """
    step = int(divisor.max()) if isinstance(divisor, torch.Tensor) else divisor
    return math.ceil(x / step) * step
469
470
471
def clean_str(s):
    """Return ``s`` with each special character from a fixed set replaced by an underscore."""
    special_chars = "[|@#!¡·$€%&()=?¿^*;:,¨´><+]"
    return re.sub(pattern=special_chars, repl="_", string=s)
474
475
476
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a function giving a sinusoidal ramp from ``y1`` at x=0 to ``y2`` at x=steps.

    See https://arxiv.org/pdf/1812.01187.pdf
    """
    def ramp(x):
        fraction = (1 - math.cos(x * math.pi / steps)) / 2
        return fraction * (y2 - y1) + y1

    return ramp
479
480
481
def colorstr(*input):
    """Wrap a string in ANSI escape codes, i.e. colorstr('blue', 'hello world').

    The last argument is the string, all preceding arguments are color/style names. With a single
    argument the style defaults to blue + bold. The reset code is always appended.
    See https://en.wikipedia.org/wiki/ANSI_escape_code
    """
    if len(input) > 1:
        *args, string = input
    else:
        *args, string = 'blue', 'bold', input[0]

    base_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
    colors = {}
    for offset, color in enumerate(base_names):
        colors[color] = f'\033[{30 + offset}m'  # basic colors, codes 30-37
    for offset, color in enumerate(base_names):
        colors[f'bright_{color}'] = f'\033[{90 + offset}m'  # bright colors, codes 90-97
    colors['end'] = '\033[0m'  # misc
    colors['bold'] = '\033[1m'
    colors['underline'] = '\033[4m'

    codes = ''.join(colors[x] for x in args)
    return codes + f'{string}' + colors['end']
504
505
506
def labels_to_class_weights(labels, nc=80):
    """Compute normalized inverse-frequency class weights from training labels.

    Args:
        labels: Sequence of per-image label arrays whose first column is the class index
            (i.e. [class xywh]); ``labels[0] is None`` means no labels were loaded.
        nc: Number of classes.

    Returns:
        torch.Tensor: Weights proportional to 1 / occurrences per class, summing to 1. Classes
        with no occurrences count as one occurrence. An empty tensor if no labels were loaded.
    """
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    # The builtin int replaces the np.int alias, which newer NumPy releases no longer provide
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # inverse frequency per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)
523
524
525
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """Compute one weight per image from the classes it contains.

    Args:
        labels: Sequence of per-image label arrays whose first column is the class index.
        nc: Number of classes.
        class_weights: Array with ``nc`` entries (it is reshaped to (1, nc)); not modified.

    Returns:
        numpy.ndarray: For each image, the sum over classes of class weight * class count.
    """
    # The builtin int replaces the np.int alias, which newer NumPy releases no longer provide
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    return image_weights
531
532
533
def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    """Return the 80 COCO paper category ids (from 1..90) in order of the 80-class index.

    See https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    """
    # The ten ids in 1..90 that do not occur in the 80-class list
    unused_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
    x = [category_id for category_id in range(1, 91) if category_id not in unused_ids]
    return x
543
544
545
def xyxy2xywh(x):
    """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] (center, size).

    xy1 is the top-left and xy2 the bottom-right corner. Works on a copy of ``x`` (torch tensor
    or numpy array); columns beyond the first four are carried over unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    y[:, 0] = (left + right) / 2  # x center
    y[:, 1] = (top + bottom) / 2  # y center
    y[:, 2] = right - left  # width
    y[:, 3] = bottom - top  # height
    return y
553
554
555
def xywh2xyxy(x):
    """Convert nx4 boxes from [x, y, w, h] (center, size) to [x1, y1, x2, y2].

    xy1 is the top-left and xy2 the bottom-right corner. Works on a copy of ``x`` (torch tensor
    or numpy array); columns beyond the first four are carried over unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    cx, cy, width, height = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    y[:, 0] = cx - width / 2  # top left x
    y[:, 1] = cy - height / 2  # top left y
    y[:, 2] = cx + width / 2  # bottom right x
    y[:, 3] = cy + height / 2  # bottom right y
    return y
563
564
565
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Convert nx4 boxes from normalized [x, y, w, h] to pixel [x1, y1, x2, y2].

    x values are scaled by ``w`` and shifted by ``padw``; y values are scaled by ``h`` and
    shifted by ``padh``. xy1 is the top-left and xy2 the bottom-right corner. Works on a copy.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    cx, cy, bw, bh = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    y[:, 0] = w * (cx - bw / 2) + padw  # top left x
    y[:, 1] = h * (cy - bh / 2) + padh  # top left y
    y[:, 2] = w * (cx + bw / 2) + padw  # bottom right x
    y[:, 3] = h * (cy + bh / 2) + padh  # bottom right y
    return y
573
574
575
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """Convert nx4 boxes from pixel [x1, y1, x2, y2] to normalized [x, y, w, h] (center, size).

    xy1 is the top-left and xy2 the bottom-right corner. With ``clip=True`` the input ``x`` is
    first clipped IN PLACE to (h - eps, w - eps). The conversion itself works on a copy.
    """
    if clip:
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    y[:, 0] = ((left + right) / 2) / w  # x center
    y[:, 1] = ((top + bottom) / 2) / h  # y center
    y[:, 2] = (right - left) / w  # width
    y[:, 3] = (bottom - top) / h  # height
    return y
585
586
587
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized segment points, shape (n,2), into pixel points.

    Column 0 is scaled by ``w`` and shifted by ``padw``; column 1 by ``h`` and ``padh``.
    Works on a copy of ``x`` (torch tensor or numpy array).
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    px, py = x[:, 0], x[:, 1]
    y[:, 0] = w * px + padw  # pixel x
    y[:, 1] = h * py + padh  # pixel y
    return y
593
594
595
def segment2box(segment, width=640, height=640):
    """Convert 1 segment label to 1 box label, i.e. (xy1, xy2, ...) to (xyxy).

    Only points inside the image (0 <= x <= width, 0 <= y <= height) are used.

    Args:
        segment: Array of shape (n, 2) holding the x, y points of the segment.
        width: Image width used for the inside-image constraint.
        height: Image height used for the inside-image constraint.

    Returns:
        numpy.ndarray: [x_min, y_min, x_max, y_max] of the inside points, or zeros of shape
        (1, 4) when no point lies inside the image.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Test for "no points left" by size; any(x) would also be False for points that all have x == 0
    if not x.size:
        return np.zeros((1, 4))
    return np.array([x.min(), y.min(), x.max(), y.max()])  # xyxy
601
602
603
def segments2boxes(segments):
    """Convert segment labels to box labels in xywh form.

    Each segment is an array of x, y points; its bounding box (min/max of x and y) is computed
    and all boxes are converted from xyxy to xywh together.
    """
    corners = [[xs.min(), ys.min(), xs.max(), ys.max()] for xs, ys in (seg.T for seg in segments)]  # xyxy
    return xyxy2xywh(np.array(corners))  # xywh
610
611
612
def resample_segments(segments, n=1000):
    """Resample every (m,2) segment to ``n`` points by linear interpolation along its point index.

    The list is modified in place (each entry is replaced by an (n,2) array) and also returned.
    """
    for idx, seg in enumerate(segments):
        positions = np.linspace(0, len(seg) - 1, n)  # fractional point indices to sample at
        known = np.arange(len(seg))
        columns = [np.interp(positions, known, seg[:, axis]) for axis in range(2)]
        segments[idx] = np.concatenate(columns).reshape(2, -1).T  # segment xy
    return segments
619
620
621
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale coords (xyxy) from img1_shape to img0_shape, in place, and clip them to img0_shape.

    Args:
        img1_shape: (height, width) the coords currently refer to.
        coords: Boxes whose first four columns are x1, y1, x2, y2; modified in place.
        img0_shape: (height, width) of the target image.
        ratio_pad: Optional pair; ``ratio_pad[0][0]`` is used as gain and ``ratio_pad[1]`` as the
            (x, y) padding. When None both are derived from the two shapes.

    Returns:
        The same ``coords`` object.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = pad_x, pad_y  # wh padding

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
635
636
def scale_polys(img1_shape, polys, img0_shape, ratio_pad=None):
    """Rescale polygons (xyxyxyxy) from img1_shape to img0_shape, in place. No clipping is applied.

    Args:
        img1_shape: (height, width) the polygons currently refer to.
        polys: Array/tensor whose first eight columns are x1, y1, ..., x4, y4; modified in place.
        img0_shape: (height, width) of the target image.
        ratio_pad: Optional pair; ``ratio_pad[0][0]`` is used as gain and ``ratio_pad[1]`` as the
            (x, y) padding. When None both are derived from the two shapes.

    Returns:
        The same ``polys`` object.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]  # h_ratios
        pad = ratio_pad[1]  # wh_paddings
    else:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = resized / raw
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = pad_x, pad_y  # wh padding

    x_columns, y_columns = [0, 2, 4, 6], [1, 3, 5, 7]
    polys[:, x_columns] -= pad[0]  # x padding
    polys[:, y_columns] -= pad[1]  # y padding
    polys[:, :8] /= gain  # Rescale poly shape to img0_shape
    return polys
651
652
def clip_polys(polys, shape):
    """Clip xyxyxyxy polygons in place to an image shape (height, width).

    x columns (0, 2, 4, 6) are limited to [0, shape[1]] and y columns (1, 3, 5, 7) to [0, shape[0]].
    """
    if isinstance(polys, torch.Tensor):  # faster individually
        for column in range(8):
            upper = shape[1] if column % 2 == 0 else shape[0]  # even columns are x, odd columns are y
            polys[:, column].clamp_(0, upper)
    else:  # np.array (faster grouped)
        x_columns, y_columns = [0, 2, 4, 6], [1, 3, 5, 7]
        polys[:, x_columns] = polys[:, x_columns].clip(0, shape[1])  # x1, x2, x3, x4
        polys[:, y_columns] = polys[:, y_columns].clip(0, shape[0])  # y1, y2, y3, y4
666
667
def clip_coords(boxes, shape):
    """Clip xyxy bounding boxes in place to an image shape (height, width).

    x columns (0, 2) are limited to [0, shape[1]] and y columns (1, 3) to [0, shape[0]].
    """
    if isinstance(boxes, torch.Tensor):  # faster individually
        for column in range(4):
            upper = shape[1] if column % 2 == 0 else shape[0]  # even columns are x, odd columns are y
            boxes[:, column].clamp_(0, upper)
    else:  # np.array (faster grouped)
        x_columns, y_columns = [0, 2], [1, 3]
        boxes[:, x_columns] = boxes[:, x_columns].clip(0, shape[1])  # x1, x2
        boxes[:, y_columns] = boxes[:, y_columns].clip(0, shape[0])  # y1, y2
677
678
679
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: tensor indexed as (image, candidate, values); per candidate the values are
            4 box numbers (center x, center y, width, height), 1 objectness score, then one score per class
        conf_thres: minimum objectness and minimum final confidence, must be in [0, 1]
        iou_thres: IoU threshold passed to torchvision NMS, must be in [0, 1]
        classes: optional class indices to keep; None keeps all classes
        agnostic: if True, NMS is applied across classes (no per-class box offset)
        multi_label: if True (and more than one class), a box may yield one detection per class above conf_thres
        labels: optional per-image apriori labels (rows of [cls, box]) appended to the candidates
        max_det: maximum number of detections kept per image

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height (min_wh is only used in commented-out code)
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS (hard-coded off, so the merge branch below never runs)

    t = time.time()
    # Every entry initially refers to the same empty (0,6) tensor; entries are replaced per image below
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T  # i: box index, j: class index
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shifting each box by class index * max_wh keeps boxes of different classes apart,
        # so a single NMS call acts per class (no shift when agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        # The limit is checked after each image; images not yet processed keep the empty tensor
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
771
772
def non_max_suppression_obb(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                            labels=(), max_det=1500):
    """Run rotated-box Non-Maximum Suppression (NMS) on raw OBB inference output.

    Args:
        prediction (tensor): (b, n_all_anchors, [cx cy l s obj num_cls theta_cls]); the trailing 180 columns
            are the per-degree angle scores.
        conf_thres (float): threshold in [0, 1], applied to objectness and then to obj_conf * cls_conf.
        iou_thres (float): threshold in [0, 1] passed on to obb_nms.
        classes (list | None): when given, only detections whose class id is in this list are kept.
        agnostic (bool): True = NMS will be applied between elements of different categories.
        multi_label (bool): keep every class above conf_thres per box (only effective when nc > 1).
        labels: () or a per-image sequence of apriori labels (autolabelling); column 0 is read as the
            class id and columns 1:5 as the box.
        max_det (int): upper bound on detections kept per image.

    Returns:
        list of detections, len=batch_size, one (n,7) tensor per image [xylsθ, conf, cls] θ ∈ [-pi/2, pi/2)
    """
    num_classes = prediction.shape[2] - 5 - 180  # columns left after box(4) + obj(1) + theta bins(180)
    candidates = prediction[..., 4] > conf_thres  # objectness pre-filter
    theta_start = num_classes + 5  # first angle-score column

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    max_wh = 4096  # (pixels) per-class offset added to box centres for batched NMS
    max_nms = 30000  # maximum number of boxes passed into obb_nms
    time_limit = 30.0  # seconds to quit after
    multi_label = multi_label & (num_classes > 1)  # multiple labels per box

    started = time.time()
    output = [torch.zeros((0, 7), device=prediction.device)] * prediction.shape[0]
    for img_idx, det in enumerate(prediction):  # image index, image inference
        det = det[candidates[img_idx]]  # (n_conf_thres, [cx cy l s obj num_cls theta_cls])

        # Cat apriori labels if autolabelling
        # NOTE(review): `prior` has nc + 5 columns while `det` has nc + 5 + 180, so this cat looks like it
        # cannot succeed when labels are supplied -- confirm the expected label layout before relying on it.
        if labels and len(labels[img_idx]):
            lab = labels[img_idx]
            prior = torch.zeros((len(lab), num_classes + 5), device=det.device)
            prior[:, :4] = lab[:, 1:5]  # box
            prior[:, 4] = 1.0  # conf
            prior[range(len(lab)), lab[:, 0].long() + 5] = 1.0  # cls
            det = torch.cat((det, prior), 0)

        # If none remain process next image
        if not det.shape[0]:
            continue

        # Compute conf = obj_conf * cls_conf (class columns only, angle scores untouched)
        det[:, 5:theta_start] *= det[:, 4:5]

        # Pick the best angle bin (int in [0, 179]) and map it to radians in [-pi/2, pi/2)
        theta_bin = det[:, theta_start:].max(1, keepdim=True)[1]
        theta = (theta_bin - 90) / 180 * pi

        # Detections matrix nx7 (xyls, θ, conf, cls)
        if multi_label:
            rows, cls_ids = (det[:, 5:theta_start] > conf_thres).nonzero(as_tuple=False).T
            det = torch.cat((det[rows, :4], theta[rows], det[rows, cls_ids + 5, None], cls_ids[:, None].float()), 1)
        else:  # best class only
            conf, cls_ids = det[:, 5:theta_start].max(1, keepdim=True)
            keep = conf.view(-1) > conf_thres
            det = torch.cat((det[:, :4], theta, conf, cls_ids.float()), 1)[keep]

        # Filter by class
        if classes is not None:
            wanted = torch.tensor(classes, device=det.device)
            det = det[(det[:, 6:7] == wanted).any(1)]

        # Check shape
        n = det.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes: keep the max_nms most confident
            det = det[det[:, 5].argsort(descending=True)[:max_nms]]

        # Batched NMS: shifting centres by class id * max_wh keeps different classes from suppressing each other
        offset = det[:, 6:7] * (0 if agnostic else max_wh)
        rboxes = det[:, :5].clone()
        rboxes[:, :2] = rboxes[:, :2] + offset
        scores = det[:, 5]
        _, kept = obb_nms(rboxes, scores, iou_thres)
        if kept.shape[0] > max_det:  # limit detections
            kept = kept[:max_det]

        output[img_idx] = det[kept]
        if (time.time() - started) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
863
864
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    """Strip training-only state from checkpoint 'f' to finalize training, optionally saving as 's'.

    The checkpoint is loaded on CPU, its 'model' entry is replaced by 'ema' when that entry is truthy,
    the bookkeeping keys are set to None, 'epoch' becomes -1, and the model is converted to FP16 with
    gradients disabled. The result overwrites 'f' unless 's' is given.
    """
    ckpt = torch.load(f, map_location=torch.device('cpu'))
    if ckpt.get('ema'):
        ckpt['model'] = ckpt['ema']  # replace model with ema

    # Drop everything that is only needed to resume training
    for key in ('optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates'):
        ckpt[key] = None
    ckpt['epoch'] = -1

    model = ckpt['model']
    model.half()  # to FP16
    for param in model.parameters():
        param.requires_grad = False

    target = s or f
    torch.save(ckpt, target)
    mb = os.path.getsize(target) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
878
879
880
def print_mutation(results, hyp, save_dir, bucket):
    """Record one hyperparameter-evolution generation in evolve.csv, on screen and in hyp_evolve.yaml.

    Args:
        results (tuple): metric values of this generation, in the order of `metric_keys` below.
        hyp (dict): hyperparameter name -> value; appended to the CSV row and dumped to the YAML file.
        save_dir (Path): directory containing evolve.csv / hyp_evolve.yaml (joined with `/`).
        bucket (str): optional bucket name; when truthy, evolve.csv is synced with `gsutil cp`
            before logging and both files are uploaded afterwards.
    """
    evolve_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'hyp_evolve.yaml'
    metric_keys = ('metrics/precision', 'metrics/recall', 'metrics/HBBmAP.5', 'metrics/HBBmAP.5:.95',
                   'val/box_loss', 'val/obj_loss', 'val/cls_loss', 'val/theta_loss')
    n_metrics = len(metric_keys)  # leading CSV columns that hold results rather than hyperparameters
    keys = tuple(k.strip() for k in metric_keys + tuple(hyp.keys()))  # [results + hyps]
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv (the header row is only written when the file is created)
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    # Read the CSV before opening the YAML for writing, so a read failure cannot leave a truncated YAML behind.
    data = pd.read_csv(evolve_csv)
    data = data.rename(columns=lambda x: x.strip())  # strip keys
    # fitness() still receives the first 7 columns, exactly as before, so the choice of best generation
    # is unchanged by this edit.
    i = np.argmax(fitness(data.values[:, :7]))
    with open(evolve_yaml, 'w') as f:
        # The header lists every metric column; a hard-coded 7 used to drop 'val/theta_loss'.
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data) - 1}\n' +
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:n_metrics]) + '\n' +
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :n_metrics]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
917
918
919
def apply_classifier(x, model, img, im0):
    """Apply a second-stage classifier to YOLO outputs and keep only detections both stages agree on.

    Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()

    Args:
        x (list): per-image detections; columns 0:4 are used as xyxy boxes and column 5 as the class id.
            Entries that are None or empty are left untouched.
        model: callable classifier; its output is reduced with argmax(1).
        img: network input batch; only img.shape[2:] is read (passed to scale_coords).
        im0: original image as an ndarray, or a list of them indexed like x.

    Returns:
        The same list x, with each processed entry filtered in place.
    """
    originals = [im0] if isinstance(im0, np.ndarray) else im0
    for i, det in enumerate(x):  # per image
        if det is None or not len(det):
            continue
        det = det.clone()  # work on a copy; x[i] itself is only filtered at the end

        # Reshape and pad cutouts
        sq = xyxy2xywh(det[:, :4])  # boxes
        sq[:, 2:] = sq[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square (longest side)
        sq[:, 2:] = sq[:, 2:] * 1.3 + 30  # pad
        det[:, :4] = xywh2xyxy(sq).long()

        # Rescale boxes from img_size to im0 size
        scale_coords(img.shape[2:], det[:, :4], originals[i].shape)

        # Classes
        first_stage_cls = det[:, 5].long()
        crops = []
        for row in det:  # per item
            top, bottom = int(row[1]), int(row[3])
            left, right = int(row[0]), int(row[2])
            crop = originals[i][top:bottom, left:right]
            patch = cv2.resize(crop, (224, 224))  # BGR

            patch = patch[:, :, ::-1].transpose(2, 0, 1)  # reverse channels (BGR to RGB), HWC to 3x224x224
            patch = np.ascontiguousarray(patch, dtype=np.float32)  # uint8 to float32
            patch /= 255  # 0 - 255 to 0.0 - 1.0
            crops.append(patch)

        second_stage_cls = model(torch.Tensor(crops).to(det.device)).argmax(1)  # classifier prediction
        x[i] = x[i][first_stage_cls == second_stage_cls]  # retain matching class detections

    return x
953
954
955
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    Args:
        path (str | Path): desired file or directory path.
        exist_ok (bool): if True, return `path` unchanged even when it already exists.
        sep (str): text placed between the original name and the number.
        mkdir (bool): if True, create the returned path as a directory (with parents).

    Returns:
        Path: `path` itself when it is free (or exist_ok), otherwise the name numbered one above the
        highest numbered sibling found, starting at 2. A file suffix is kept after the number.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        # For files, number the name before the suffix: a.txt -> a2.txt
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

        # Escape the prefix so glob metacharacters in the path are taken literally
        siblings = glob.glob(glob.escape(f"{path}{sep}") + "*")  # similar paths

        # Match against each sibling's final component only, with name and sep escaped, so that
        # regex metacharacters in the name and numbered parent directories cannot skew the index.
        numbered = re.compile(rf"{re.escape(path.name)}{re.escape(sep)}(\d+)")
        found = (numbered.match(Path(d).name) for d in siblings)
        indices = [int(m.group(1)) for m in found if m]

        n = max(indices) + 1 if indices else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
968
969
970
# Variables
# Terminal window width handed to tqdm; forced to 0 when is_docker() reports a Docker environment
NCOLS = shutil.get_terminal_size().columns if not is_docker() else 0
972
973