CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
hukaixuan19970627

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/models/tf.py
Views: 475
1
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
"""
3
TensorFlow, Keras and TFLite versions of YOLOv5
4
Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
5
6
Usage:
7
$ python models/tf.py --weights yolov5s.pt
8
9
Export:
10
$ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
11
"""
12
13
import argparse
14
import sys
15
from copy import deepcopy
16
from pathlib import Path
17
18
# Locate this file on disk and make sure its grandparent directory is importable.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))  # add ROOT to PATH
22
# ROOT = ROOT.relative_to(Path.cwd()) # relative
23
24
import numpy as np
25
import tensorflow as tf
26
import torch
27
import torch.nn as nn
28
from tensorflow import keras
29
30
from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, Concat, Conv, DWConv, Focus, autopad
31
from models.experimental import CrossConv, MixConv2d, attempt_load
32
from models.yolo import Detect
33
from utils.activations import SiLU
34
from utils.general import LOGGER, make_divisible, print_args
35
36
37
class TFBN(keras.layers.Layer):
    # TensorFlow BatchNormalization wrapper
    def __init__(self, w=None):
        # w: source batch-norm module; its bias, weight, running_mean, running_var and eps are read here.
        # Despite the None default, w is dereferenced unconditionally, so callers must supply it.
        super().__init__()

        def as_constant(tensor):
            # Freeze a source tensor into a Keras constant initializer
            return keras.initializers.Constant(tensor.numpy())

        self.bn = keras.layers.BatchNormalization(
            beta_initializer=as_constant(w.bias),
            gamma_initializer=as_constant(w.weight),
            moving_mean_initializer=as_constant(w.running_mean),
            moving_variance_initializer=as_constant(w.running_var),
            epsilon=w.eps,
        )

    def call(self, inputs):
        # Apply the wrapped BatchNormalization layer
        normalized = self.bn(inputs)
        return normalized
50
51
52
class TFPad(keras.layers.Layer):
    # Constant zero padding of `pad` on both sides of axes 1 and 2 of a 4-D input; axes 0 and 3 are untouched
    def __init__(self, pad):
        super().__init__()
        both_sides = [pad, pad]
        untouched = [0, 0]
        self.pad = tf.constant([untouched, both_sides, both_sides, untouched])

    def call(self, inputs):
        padded = tf.pad(inputs, self.pad, mode='constant', constant_values=0)
        return padded
59
60
61
class TFConv(keras.layers.Layer):
    # Standard convolution: Conv2D -> (BatchNorm if the source module has `bn`) -> activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        # ch_in, ch_out, kernel, stride, padding, groups, apply-activation flag, source module
        # w supplies the weights: w.conv.weight, w.conv.bias (only when there is no w.bn), w.bn, w.act.
        # Raises AssertionError for g != 1 or a non-int kernel, and Exception when an activation is
        # requested (act truthy) but w.act is not LeakyReLU, Hardswish or SiLU.
        super().__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
        # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
        # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
        # so strided convs are padded explicitly with autopad(k, p) and then run with 'VALID' padding.
        has_bn = hasattr(w, 'bn')
        conv = keras.layers.Conv2D(
            c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=not has_bn,
            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer='zeros' if has_bn else keras.initializers.Constant(w.conv.bias.numpy()))
        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
        self.bn = TFBN(w.bn) if has_bn else tf.identity

        # YOLOv5 activations
        # Check `act` first: when no activation is requested the result is tf.identity regardless of
        # w.act, so an unrecognised w.act (presumably nn.Identity on act=False convs - TODO confirm
        # against models.common.Conv) must not raise.
        if not act:
            self.act = tf.identity
        elif isinstance(w.act, nn.LeakyReLU):
            self.act = lambda x: keras.activations.relu(x, alpha=0.1)
        elif isinstance(w.act, nn.Hardswish):
            self.act = lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
        elif isinstance(w.act, (nn.SiLU, SiLU)):
            self.act = lambda x: keras.activations.swish(x)
        else:
            raise Exception(f'no matching TensorFlow activation found for {w.act}')

    def call(self, inputs):
        # conv -> bn (or identity) -> activation (or identity)
        return self.act(self.bn(self.conv(inputs)))
90
91
92
class TFFocus(keras.layers.Layer):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        # ch_in, ch_out, kernel, stride, padding, groups; the inner conv sees 4x the input channels
        super().__init__()
        self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)

    def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
        # Inputs are used as-is: no 0-255 -> 0-1 normalisation is applied here.
        even = slice(None, None, 2)
        odd = slice(1, None, 2)
        # Four interleaved sub-grids of axes 1 and 2, stacked along the last axis
        offsets = ((even, even), (odd, even), (even, odd), (odd, odd))
        patches = [inputs[:, a, b, :] for a, b in offsets]
        return self.conv(tf.concat(patches, 3))
105
106
107
class TFBottleneck(keras.layers.Layer):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        self.cv2 = TFConv(hidden, c2, 3, 1, g=g, w=w.cv2)
        # The residual connection is only used when requested and the channel counts match
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        out = self.cv2(self.cv1(inputs))
        if self.add:
            return inputs + out
        return out
118
119
120
class TFConv2d(keras.layers.Layer):
    # Substitution for PyTorch nn.Conv2D
    def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
        # c1 is accepted for signature parity but not used; weights come from w.weight (and w.bias if bias)
        super().__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        # Reorder the source kernel axes with permute(2, 3, 1, 0) before handing it to Keras
        kernel = w.weight.permute(2, 3, 1, 0).numpy()
        bias_init = keras.initializers.Constant(w.bias.numpy()) if bias else None
        self.conv = keras.layers.Conv2D(
            c2, k, s, 'VALID', use_bias=bias,
            kernel_initializer=keras.initializers.Constant(kernel),
            bias_initializer=bias_init)

    def call(self, inputs):
        result = self.conv(inputs)
        return result
132
133
134
class TFBottleneckCSP(keras.layers.Layer):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        self.cv2 = TFConv2d(c1, hidden, 1, 1, bias=False, w=w.cv2)
        self.cv3 = TFConv2d(hidden, hidden, 1, 1, bias=False, w=w.cv3)
        self.cv4 = TFConv(2 * hidden, c2, 1, 1, w=w.cv4)
        self.bn = TFBN(w.bn)
        self.act = lambda x: keras.activations.relu(x, alpha=0.1)
        # n chained bottlenecks, each taking its weights from the matching entry of w.m
        blocks = [TFBottleneck(hidden, hidden, shortcut, g, e=1.0, w=w.m[idx]) for idx in range(n)]
        self.m = keras.Sequential(blocks)

    def call(self, inputs):
        deep = self.cv3(self.m(self.cv1(inputs)))  # branch through the bottleneck stack
        skip = self.cv2(inputs)  # plain 1x1 branch
        merged = tf.concat((deep, skip), axis=3)
        return self.cv4(self.act(self.bn(merged)))
152
153
154
class TFC3(keras.layers.Layer):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c1, hidden, 1, 1, w=w.cv2)
        self.cv3 = TFConv(2 * hidden, c2, 1, 1, w=w.cv3)
        # n chained bottlenecks, each taking its weights from the matching entry of w.m
        blocks = [TFBottleneck(hidden, hidden, shortcut, g, e=1.0, w=w.m[idx]) for idx in range(n)]
        self.m = keras.Sequential(blocks)

    def call(self, inputs):
        deep = self.m(self.cv1(inputs))  # branch through the bottleneck stack
        skip = self.cv2(inputs)  # plain 1x1 branch
        return self.cv3(tf.concat((deep, skip), axis=3))
167
168
169
class TFSPP(keras.layers.Layer):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13), w=None):
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        # cv2 consumes the un-pooled features plus one pooled copy per kernel size
        self.cv2 = TFConv(hidden * (len(k) + 1), c2, 1, 1, w=w.cv2)
        self.m = [keras.layers.MaxPool2D(pool_size=size, strides=1, padding='SAME') for size in k]

    def call(self, inputs):
        x = self.cv1(inputs)
        branches = [x, *(pool(x) for pool in self.m)]
        return self.cv2(tf.concat(branches, 3))
181
182
183
class TFSPPF(keras.layers.Layer):
    # Spatial pyramid pooling-Fast layer
    def __init__(self, c1, c2, k=5, w=None):
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        self.cv2 = TFConv(hidden * 4, c2, 1, 1, w=w.cv2)
        self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

    def call(self, inputs):
        # Apply the same pool three times in sequence, keeping every intermediate result
        stages = [self.cv1(inputs)]
        for _ in range(3):
            stages.append(self.m(stages[-1]))
        return self.cv2(tf.concat(stages, 3))
197
198
199
class TFDetect(keras.layers.Layer):
    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):  # detection layer
        # nc: class count, anchors: per-layer anchor lists, ch: input channels per layer,
        # imgsz: (h, w) input size, w: source Detect module (stride, anchors and m are read from it)
        super().__init__()
        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
        self.nc = nc  # number of classes
        self.no = nc + 5 + 180  # number of outputs per anchor (plain YOLOv5 would be nc + 5)
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [tf.zeros(1)] * self.nl  # init grid
        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
        # Scale the anchors by each layer's stride
        per_layer_stride = tf.reshape(self.stride, [self.nl, 1, 1])
        self.anchor_grid = tf.reshape(self.anchors * per_layer_stride, [self.nl, 1, -1, 1, 2])
        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
        self.training = False  # set to False after building model
        self.imgsz = imgsz
        for i in range(self.nl):
            ny, nx = self._feature_hw(i)
            self.grid[i] = self._make_grid(nx, ny)

    def _feature_hw(self, i):
        # Feature-map height and width of detection layer i: image size floor-divided by its stride
        return self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]

    def call(self, inputs):
        z = []  # inference output
        x = []
        for i in range(self.nl):
            ny, nx = self._feature_hw(i)
            # Reshape the head output to (-1, ny*nx, na, no), then swap to (-1, na, ny*nx, no)
            head = tf.reshape(self.m[i](inputs[i]), [-1, ny * nx, self.na, self.no])
            head = tf.transpose(head, [0, 2, 1, 3])
            x.append(head)
            if self.training:
                continue

            # inference
            y = tf.sigmoid(head)
            xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
            wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
            # Normalize xywh to 0-1 to reduce calibration error
            img_wh = tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
            xy = xy / img_wh
            wh = wh / img_wh
            y = tf.concat([xy, wh, y[..., 4:]], -1)
            z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

        if self.training:
            return x
        return tf.concat(z, 1), x

    @staticmethod
    def _make_grid(nx=20, ny=20):
        # Grid of (x, y) cell indices, flattened to shape [1, 1, ny * nx, 2] and cast to float32
        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
        cells = tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2])
        return tf.cast(cells, dtype=tf.float32)
246
247
248
class TFUpsample(keras.layers.Layer):
    def __init__(self, size, scale_factor, mode, w=None):  # warning: all arguments needed including 'w'
        # Only 2x upsampling is supported; `size` and `w` are accepted but not used.
        super().__init__()
        assert scale_factor == 2, "scale_factor must be 2"

        def resize_2x(x):
            # Double axes 1 and 2 using the resize method named by `mode`
            target = (x.shape[1] * 2, x.shape[2] * 2)
            return tf.image.resize(x, target, method=mode)

        self.upsample = resize_2x

    def call(self, inputs):
        upsampled = self.upsample(inputs)
        return upsampled
260
261
262
class TFConcat(keras.layers.Layer):
    def __init__(self, dimension=1, w=None):
        # Only dimension 1 is accepted; it is mapped to axis 3 (NCHW -> NHWC, per the assert message)
        super().__init__()
        assert dimension == 1, "convert only NCHW to NHWC concat"
        self.d = 3

    def call(self, inputs):
        joined = tf.concat(inputs, self.d)
        return joined
270
271
272
def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
    # Build the TF/Keras layer stack described by `d`, taking weights from the PyTorch `model`.
    # Returns (keras.Sequential of all layers, sorted list of layer indices collected from the 'from' fields).
    # NOTE: `ch` is appended to in place. Module names and string arguments from `d` are passed to eval()
    # in this function's scope (so e.g. 'nc' and 'anchors' resolve to the locals below) - only use
    # trusted model dicts.
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5 + 180)  # number of outputs = anchors * (classes + 5 + 180)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m_str = m
        if isinstance(m, str):
            m = eval(m)  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass  # strings that are not names in scope are kept as plain strings

        if n > 1:
            n = max(round(n * gd), 1)  # depth gain
        if m in (nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3):
            c1, c2 = ch[f], args[0]
            if c2 != no:
                c2 = make_divisible(c2 * gw, 8)  # width gain, rounded by make_divisible(..., 8)

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C3):
                # These modules take the repeat count as an argument instead of being stacked
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            args.append(imgsz)
        else:
            c2 = ch[f]

        tf_m = eval('TF' + m_str.replace('nn.', ''))  # TF counterpart class, looked up by name
        if n > 1:
            m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)])  # module
        else:
            m_ = tf_m(*args, w=model.model[i])  # module

        # A throwaway PyTorch module is built only to count parameters for the log line
        torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        n_params = sum(x.numel() for x in torch_m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, n_params  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{n_params:>10} {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return keras.Sequential(layers), sorted(save)
322
323
324
class TFModel:
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)):  # model, channels, classes
        # cfg: model dict or path to a *.yaml file; model: PyTorch model passed on to parse_model
        super().__init__()
        if not isinstance(cfg, dict):  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            # NOTE: yaml.FullLoader is used here - only load trusted config files
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict
        else:
            self.yaml = cfg  # model dict

        # Define model
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        # parse_model receives a deep copy of the dict
        self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)

    def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
                conf_thres=0.25):
        # Run the layers in order. Without tf_nms, returns only the first element of the final output;
        # with tf_nms, returns (nms result, second element of the final output).
        y = []  # outputs
        x = inputs
        for m in self.model.layers:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            x = m(x)  # run
            y.append(x if m.i in self.savelist else None)  # save output

        if not tf_nms:
            return x[0]  # output only first tensor = [xywh, conf, remaining channels...]

        # Add TensorFlow NMS
        boxes = self._xywh2xyxy(x[0][..., :4])
        probs = x[0][:, :, 4:5]
        # NOTE(review): TFDetect emits nc + 5 + 180 channels per box, so `5:` spans every channel after
        # the confidence, not only nc class scores - confirm this is what NMS should see.
        classes = x[0][:, :, 5:]
        scores = probs * classes
        if agnostic_nms:
            nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
        else:
            boxes = tf.expand_dims(boxes, 2)
            nms = tf.image.combined_non_max_suppression(
                boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)
        return nms, x[1]

    @staticmethod
    def _xywh2xyxy(xywh):
        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
        x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
        half_w, half_h = w / 2, h / 2
        corners = [x - half_w, y - half_h, x + half_w, y + half_h]
        return tf.concat(corners, axis=-1)
379
380
381
class AgnosticNMS(keras.layers.Layer):
    # TF Agnostic NMS
    def call(self, input, topk_all, iou_thres, conf_thres):
        # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
        def per_element(sample):
            return self._nms(sample, topk_all, iou_thres, conf_thres)

        return tf.map_fn(per_element, input,
                         fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
                         name='agnostic_nms')

    @staticmethod
    def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
        # x is a (boxes, classes, scores) triple. Returns boxes, scores and class indices of the kept
        # detections, each padded to topk_all rows, plus the number of kept detections.
        boxes, classes, scores = x
        class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
        scores_inp = tf.reduce_max(scores, -1)  # best score per box, whatever its class
        keep = tf.image.non_max_suppression(
            boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)

        kept_boxes = tf.gather(boxes, keep)
        n_missing = topk_all - tf.shape(kept_boxes)[0]  # rows needed to reach topk_all

        def pad_rows(values, fill):
            # Pad a 1-D tensor at the end with `fill`
            return tf.pad(values, paddings=[[0, n_missing]], mode="CONSTANT", constant_values=fill)

        padded_boxes = tf.pad(kept_boxes, paddings=[[0, n_missing], [0, 0]],
                              mode="CONSTANT", constant_values=0.0)
        padded_scores = pad_rows(tf.gather(scores_inp, keep), -1.0)
        padded_classes = pad_rows(tf.gather(class_inds, keep), -1.0)
        valid_detections = tf.shape(keep)[0]
        return padded_boxes, padded_scores, padded_classes, valid_detections
410
411
412
def representative_dataset_gen(dataset, ncalib=100):
    # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
    # dataset: iterable of 5-tuples (path, img, im0s, vid_cap, string); only img is used.
    # ncalib: maximum number of samples to yield (ncalib <= 0 yields nothing).
    # Each yielded item is a one-element list holding a float32 array: img with axes reordered by
    # (1, 2, 0) (presumably CHW -> HWC - TODO confirm against the dataset), a leading axis of size 1
    # added, and values divided by 255.
    for n, (_path, img, _im0s, _vid_cap, _string) in enumerate(dataset):
        # Check the bound before yielding: checking afterwards produced ncalib + 1 samples
        if n >= ncalib:
            break
        sample = np.transpose(img, [1, 2, 0])
        sample = np.expand_dims(sample, axis=0).astype(np.float32)  # astype copies, so img is not modified
        sample /= 255
        yield [sample]
421
422
423
def run(weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=(640, 640),  # inference size h,w
        batch_size=1,  # batch size
        dynamic=False,  # dynamic batch size
        ):
    # Load the PyTorch weights, rebuild the network as TensorFlow and Keras models, and run each once
    # on an all-zero input.

    # PyTorch model
    torch_im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
    model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
    model(torch_im)  # inference
    model.info()

    # TensorFlow model
    tf_im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
    tf_model.predict(tf_im)  # inference

    # Keras model
    fixed_batch = None if dynamic else batch_size  # None leaves the batch dimension unspecified
    keras_in = keras.Input(shape=(*imgsz, 3), batch_size=fixed_batch)
    keras_model = keras.Model(inputs=keras_in, outputs=tf_model.predict(keras_in))
    keras_model.summary()

    LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
445
446
447
def parse_opt():
    # Parse the command-line options, log them via print_args, and return the namespace
    cli = argparse.ArgumentParser()
    cli.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
    cli.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    cli.add_argument('--batch-size', type=int, default=1, help='batch size')
    cli.add_argument('--dynamic', action='store_true', help='dynamic batch size')
    opt = cli.parse_args()
    if len(opt.imgsz) == 1:
        opt.imgsz *= 2  # expand a single size to [size, size]
    print_args(FILE.stem, opt)
    return opt
457
458
459
def main(opt):
    # Pass every parsed option to run() as a keyword argument
    options = vars(opt)
    run(**options)


if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then build and verify the models
    opt = parse_opt()
    main(opt)
466
467