CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/utils/torch_utils.py
Views: 475
1
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
"""
3
PyTorch utils
4
"""
5
6
import datetime
7
import math
8
import os
9
import platform
10
import subprocess
11
import time
12
from contextlib import contextmanager
13
from copy import deepcopy
14
from pathlib import Path
15
16
import torch
17
import torch.distributed as dist
18
import torch.nn as nn
19
import torch.nn.functional as F
20
21
from utils.general import LOGGER
22
23
try:
24
import thop # for FLOPs computation
25
except ImportError:
26
thop = None
27
28
29
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets the local master (rank 0) run the body before the other ranks.

    Ranks other than -1 and 0 block on ``dist.barrier`` before entering the body; rank 0 runs the
    body first and then calls ``dist.barrier`` itself. Rank -1 (no distributed training) never
    touches ``dist``.
    """
    must_wait = local_rank not in [-1, 0]
    if must_wait:
        dist.barrier(device_ids=[local_rank])  # non-master ranks wait here for rank 0
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])  # rank 0 is done: reach the barrier the others are waiting on
39
40
41
def date_modified(path=__file__):
    """Return the local modification date of ``path`` as 'YYYY-M-D' without zero padding, e.g. '2021-3-26'."""
    mtime = Path(path).stat().st_mtime
    stamp = datetime.datetime.fromtimestamp(mtime)
    return '%d-%d-%d' % (stamp.year, stamp.month, stamp.day)
45
46
47
def git_describe(path=Path(__file__).parent):  # path must be a directory
    """Return a human-readable git description of the repo at ``path``, e.g. 'v5.0-5-g3e25f1e'.

    See https://git-scm.com/docs/git-describe. Returns '' when ``path`` is not inside a git
    repository or the ``git`` executable cannot be run.
    """
    # Pass an argv list (no shell): paths containing spaces or shell metacharacters are handled
    # safely instead of being split / interpreted by the shell.
    cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
    try:
        out = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, OSError):  # not a git repository / git not installed
        return ''
    return out.decode().strip()  # strip() also removes a trailing '\r' on CRLF platforms
54
55
56
def select_device(device='', batch_size=0, newline=True):
    """Pick the torch device to run on and log a short environment summary.

    Args:
        device: 'cpu', '' (auto-select), or CUDA indices such as '0' or '0,1,2,3'; a leading
            'cuda:' is stripped, so 'cuda:0' means '0'.
        batch_size: when > 0 and several GPUs are requested, it must be a multiple of the GPU count.
        newline: when False, trailing whitespace is stripped from the logged summary.

    Returns:
        torch.device('cuda:0') when CUDA is used, otherwise torch.device('cpu').

    Side effects:
        Sets os.environ['CUDA_VISIBLE_DEVICES'] when device is 'cpu' or an explicit device string.
    """
    s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
    device = str(device).strip().lower().replace('cuda:', '')  # 'cuda:0' -> '0'
    cpu = device == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide all GPUs so torch.cuda.is_available() is False
    elif device:  # explicit non-cpu request
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'

    cuda = not cpu and torch.cuda.is_available()
    if not cuda:
        s += 'CPU\n'
    else:
        ids = device.split(',') if device else '0'  # e.g. ['0', '1', '6', '7']
        n = len(ids)
        if n > 1 and batch_size > 0:
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        indent = ' ' * (len(s) + 1)  # aligns continuation lines under the first device
        rows = []
        for i, d in enumerate(ids):
            props = torch.cuda.get_device_properties(i)  # index i is relative to CUDA_VISIBLE_DEVICES
            rows.append(f"{'' if i == 0 else indent}CUDA:{d} ({props.name}, {props.total_memory / 1024 ** 2:.0f}MiB)\n")
        s += ''.join(rows)

    if not newline:
        s = s.rstrip()
    on_windows = platform.system() == 'Windows'
    LOGGER.info(s.encode().decode('ascii', 'ignore') if on_windows else s)  # drop emoji on Windows consoles
    return torch.device('cuda:0' if cuda else 'cpu')
84
85
86
def time_sync():
    """Return ``time.time()``, first waiting for queued CUDA work when CUDA is available."""
    if not torch.cuda.is_available():
        return time.time()
    torch.cuda.synchronize()  # so GPU kernels launched before this call are included in timings
    return time.time()
91
92
93
def profile(input, ops, n=10, device=None):
    """YOLOv5 speed/memory/FLOPs profiler.

    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations

    Args:
        input: a tensor, or a list of tensors; each is moved to ``device`` and has
            ``requires_grad`` set to True before being fed to every op.
        ops: a callable / nn.Module, or a list of them.
        n: number of timed forward/backward iterations per (input, op) pair.
        device: target device; ``select_device()`` is used when this is falsy.

    Returns:
        One entry per (input, op) pair, in iteration order:
        ``[params, GFLOPs, GPU_mem (GB), forward ms, backward ms, input shape, output shape]``,
        or ``None`` when running that pair raised an exception (the exception is printed).
    """
    results = []
    device = device or select_device()
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward

            # FLOPs are best-effort: 0 when thop is not installed or cannot trace this op
            # (e.g. a plain lambda). Catch Exception, not everything, so Ctrl-C still works.
            flops = 0
            if thop is not None:
                try:
                    flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
                except Exception:
                    flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
                s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
                p = sum(param.numel() for param in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results
143
144
145
def is_parallel(model):
    """Return True if ``model`` is wrapped in DataParallel or DistributedDataParallel.

    Uses ``isinstance`` so subclasses of these wrappers are recognised too; they still expose
    the wrapped network as ``.module``.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
148
149
150
def de_parallel(model):
    """Return the wrapped single-GPU model when ``model`` is DP/DDP (per is_parallel), else ``model`` itself."""
    if is_parallel(model):
        return model.module
    return model
153
154
155
def initialize_weights(model):
    """Adjust module settings in place for every submodule of ``model``.

    - nn.BatchNorm2d: eps=1e-3, momentum=0.03.
    - nn.Hardswish / LeakyReLU / ReLU / ReLU6 / SiLU: inplace=True.
    - nn.Conv2d is left untouched (a kaiming_normal_ init was deliberately disabled).
    Matching is on the exact type, so subclasses are not affected.
    """
    inplace_activations = [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]
    for module in model.modules():
        kind = type(module)
        if kind is nn.Conv2d:
            continue  # keep PyTorch's default conv init
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in inplace_activations:
            module.inplace = True
165
166
167
def find_modules(model, mclass=nn.Conv2d):
    """Return the indices into ``model.module_list`` of layers that are instances of ``mclass``."""
    matches = []
    for index, layer in enumerate(model.module_list):
        if isinstance(layer, mclass):
            matches.append(index)
    return matches
170
171
172
def sparsity(model):
    """Return the global fraction of exactly-zero entries across all parameters of ``model``.

    Returns a 0-dim tensor (zero count / element count). A model with no parameters yields 0.0
    instead of raising ZeroDivisionError.
    """
    params = list(model.parameters())  # iterated twice below
    total = sum(p.numel() for p in params)
    if total == 0:
        return 0.0  # parameter-free model: nothing can be sparse
    zeros = sum((p == 0).sum() for p in params)
    return zeros / total
179
180
181
def prune(model, amount=0.3):
    """Prune every nn.Conv2d weight of ``model`` in place with L1-unstructured pruning and print the resulting global sparsity.

    ``amount`` is forwarded to ``torch.nn.utils.prune.l1_unstructured``.
    """
    import torch.nn.utils.prune as torch_prune  # aliased so it does not shadow this function's name

    print('Pruning model... ', end='')
    for conv in (m for m in model.modules() if isinstance(m, nn.Conv2d)):
        torch_prune.l1_unstructured(conv, name='weight', amount=amount)  # prune
        torch_prune.remove(conv, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))
190
191
192
def fuse_conv_and_bn(conv, bn):
    """Fuse an nn.Conv2d followed by a batch-norm layer into a single equivalent nn.Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/. The fused layer uses the BN
    running statistics, so it reproduces ``bn(conv(x))`` with ``bn`` in eval mode. All geometry
    of ``conv`` (kernel, stride, padding, dilation, groups, padding_mode) is carried over.

    Returns:
        A new nn.Conv2d with bias, ``requires_grad_(False)``, on ``conv.weight.device``.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,  # must match, otherwise dilated convs fuse incorrectly
                          groups=conv.groups,
                          bias=True,
                          padding_mode=conv.padding_mode).requires_grad_(False).to(conv.weight.device)

    # prepare filters: W_fused = diag(gamma / sqrt(var + eps)) @ W_conv
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias: b_fused = diag(...) @ b_conv + (beta - gamma * mean / sqrt(var + eps))
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
213
214
215
def model_info(model, verbose=False, img_size=640):
    """Log a one-line model summary: layer count, parameters, trainable parameters and GFLOPs.

    Args:
        model: nn.Module. The FLOPs estimate reads ``model.stride`` (when present) and
            ``model.yaml['ch']``; if those are missing, or thop is not installed, GFLOPs are omitted.
        verbose: also print a per-parameter table (name, requires_grad, numel, shape, mean, std).
        img_size: int, or an [h, w] list/tuple, used to scale the FLOPs estimate, e.g. 640 or [640, 320].
    """
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPs (best-effort)
        from thop import profile as thop_profile  # alias: do not shadow this module's profile()
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
        flops = thop_profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        # accept tuples as well as lists; previously a tuple fell through and the FLOPs were dropped
        img_size = img_size if isinstance(img_size, (list, tuple)) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
    except Exception:  # ImportError (no thop) is an Exception too
        fs = ''

    LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
237
238
239
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """Resize a (bs, 3, h, w) image batch by ``ratio`` with bilinear interpolation.

    With ``same_shape`` the result is padded/cropped back to the original (h, w); otherwise it is
    padded up to the next multiple of ``gs``. Padding uses the value 0.447. ratio == 1.0 returns
    ``img`` unchanged (same object).
    """
    if ratio == 1.0:
        return img
    height, width = img.shape[2:]
    new_h, new_w = int(height * ratio), int(width * ratio)
    img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)  # resize
    if same_shape:
        target_h, target_w = height, width
    else:
        target_h = math.ceil(height * ratio / gs) * gs
        target_w = math.ceil(width * ratio / gs) * gs
    # negative padding crops, so this both pads and crops to the target size
    return F.pad(img, [0, target_w - new_w, 0, target_h - new_h], value=0.447)  # value = imagenet mean
250
251
252
def copy_attr(a, b, include=(), exclude=()):
    """Copy the instance attributes of ``b`` onto ``a``.

    Names starting with '_' and names in ``exclude`` are skipped. When ``include`` is non-empty,
    only names listed in it are copied.
    """
    for key, value in b.__dict__.items():
        if len(include) and key not in include:
            continue
        if key.startswith('_') or key in exclude:
            continue
        setattr(a, key, value)
259
260
261
class EarlyStopping:
    """YOLOv5 simple early stopper: signal a stop once fitness has not improved for ``patience`` epochs."""

    def __init__(self, patience=30):
        self.best_fitness = 0.0  # best fitness seen so far, e.g. mAP
        self.best_epoch = 0  # epoch at which best_fitness was recorded
        # a falsy patience (0 or None) means "never stop early"
        self.patience = patience or float('inf')
        self.possible_stop = False  # True when a stop may occur at the next epoch

    def __call__(self, epoch, fitness):
        """Record ``fitness`` for ``epoch`` and return True when training should stop."""
        # '>=' rather than '>' so a flat zero fitness early in training still advances best_epoch
        if fitness >= self.best_fitness:
            self.best_epoch, self.best_fitness = epoch, fitness
        stale_epochs = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = stale_epochs >= (self.patience - 1)
        should_stop = stale_epochs >= self.patience
        if should_stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return should_stop
282
283
284
class ModelEMA:
    """Exponential Moving Average of a model's state_dict (parameters and buffers).

    Adapted from https://github.com/rwightman/pytorch-image-models; compare
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    A smoothed copy of the weights helps some training schemes perform well. Where this object is
    created relative to model init, GPU placement and distributed wrapping matters.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """Create an eval-mode deep copy of ``model`` (unwrapped if DP/DDP) with gradients disabled."""
        base = model.module if is_parallel(model) else model
        self.ema = deepcopy(base).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # effective decay ramps up from 0 toward `decay` as updates grow (helps early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA: ema = d * ema + (1 - d) * model."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            src = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for name, ema_v in self.ema.state_dict().items():
                if not ema_v.dtype.is_floating_point:
                    continue  # non-float entries are left untouched
                ema_v.mul_(d)
                ema_v.add_((1 - d) * src[name].detach())

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public attributes of ``model`` onto the EMA copy (see copy_attr)."""
        copy_attr(self.ema, model, include, exclude)
319
320