GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/export.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit

Format                  | Example                   | `--include ...` argument
---                     | ---                       | ---
PyTorch                 | yolov5s.pt                | -
TorchScript             | yolov5s.torchscript       | `torchscript`
ONNX                    | yolov5s.onnx              | `onnx`
CoreML                  | yolov5s.mlmodel           | `coreml`
OpenVINO                | yolov5s_openvino_model/   | `openvino`
TensorFlow SavedModel   | yolov5s_saved_model/      | `saved_model`
TensorFlow GraphDef     | yolov5s.pb                | `pb`
TensorFlow Lite         | yolov5s.tflite            | `tflite`
TensorFlow.js           | yolov5s_web_model/        | `tfjs`
TensorRT                | yolov5s.engine            | `engine`

Usage:
    $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml openvino saved_model tflite tfjs

Inference:
    $ python path/to/detect.py --weights yolov5s.pt
                                         yolov5s.torchscript
                                         yolov5s.onnx
                                         yolov5s.mlmodel (under development)
                                         yolov5s_openvino_model (under development)
                                         yolov5s_saved_model
                                         yolov5s.pb
                                         yolov5s.tflite
                                         yolov5s.engine

TensorFlow.js:
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
    $ npm install
    $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
    $ npm start
"""

import argparse
import json
import os
import subprocess
import sys
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import Conv
from models.experimental import attempt_load
from models.yolo import Detect
from utils.activations import SiLU
from utils.datasets import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args,
                           url2file)
from utils.torch_utils import select_device


def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    try:
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript')

        ts = torch.jit.trace(model, im, strict=False)
        d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
        (optimize_for_mobile(ts) if optimize else ts).save(str(f), _extra_files=extra_files)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')


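# Usage sketch: reloading the TorchScript artifact and its embedded metadata
# (assumes the default weights name; variable names below are illustrative only):
#   extra_files = {'config.txt': ''}  # populated by torch.jit.load
#   ts_model = torch.jit.load('yolov5s.torchscript', _extra_files=extra_files)
#   meta = json.loads(extra_files['config.txt'])  # {'shape': ..., 'stride': ..., 'names': ...}
#   y = ts_model(torch.zeros(*meta['shape']))

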
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    try:
        check_requirements(('onnx',))
        import onnx

        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
        f = file.with_suffix('.onnx')

        torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                        } if dynamic else None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        # LOGGER.info(onnx.helper.printable_graph(model_onnx.graph))  # print

        # Simplify
        if simplify:
            try:
                check_requirements(('onnx-simplifier',))
                import onnxsim

                LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(im.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                LOGGER.info(f'{prefix} simplifier failure: {e}')
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        LOGGER.info(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')


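# Inference sketch: running the ONNX artifact with onnxruntime (a separate
# install, not required by this script):
#   import onnxruntime
#   session = onnxruntime.InferenceSession('yolov5s.onnx', providers=['CPUExecutionProvider'])
#   y = session.run(['output'], {'images': im.cpu().numpy()})[0]

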
def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
    # YOLOv5 CoreML export
    ct_model = None
    try:
        check_requirements(('coremltools',))
        import coremltools as ct

        LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
        f = file.with_suffix('.mlmodel')

        model.train()  # CoreML exports should be placed in model.train() mode
        ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
        ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
        ct_model.save(f)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')

    return ct_model


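# Prediction sketch (macOS only; `pil_image` is a hypothetical PIL.Image resized
# to the export resolution, matching the 'image' input declared above):
#   import coremltools as ct
#   mlmodel = ct.models.MLModel('yolov5s.mlmodel')
#   y = mlmodel.predict({'image': pil_image})

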
def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
    # YOLOv5 OpenVINO export
    try:
        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.inference_engine as ie

        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        f = str(file).replace('.pt', '_openvino_model' + os.sep)

        cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}"
        subprocess.check_output(cmd, shell=True)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


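# Loading sketch: reading the exported IR with the 2021.x Inference Engine API
# imported above (paths are illustrative):
#   from openvino.inference_engine import IECore
#   core = IECore()
#   net = core.read_network(model='yolov5s_openvino_model/yolov5s.xml',
#                           weights='yolov5s_openvino_model/yolov5s.bin')
#   exec_net = core.load_network(network=net, device_name='CPU')

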
def export_saved_model(model, im, file, dynamic,
                       tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
                       conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')):
    # YOLOv5 TensorFlow saved_model export
    keras_model = None
    try:
        import tensorflow as tf
        from tensorflow import keras

        from models.tf import TFDetect, TFModel

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = str(file).replace('.pt', '_saved_model')
        batch_size, ch, *imgsz = list(im.shape)  # BCHW

        tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
        im = tf.zeros((batch_size, *imgsz, 3))  # BHWC order for TensorFlow
        y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
        inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
        outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
        keras_model = keras.Model(inputs=inputs, outputs=outputs)
        keras_model.trainable = False
        keras_model.summary()
        keras_model.save(f, save_format='tf')

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')

    return keras_model


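# Loading sketch: the SavedModel directory reloads as a Keras model
# (640x640 assumes the upstream default export size):
#   import tensorflow as tf
#   saved = tf.keras.models.load_model('yolov5s_saved_model')
#   y = saved(tf.zeros((1, 640, 640, 3)))  # BHWC input

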
def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
    # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
    try:
        import tensorflow as tf
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = file.with_suffix('.pb')

        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
        frozen_func = convert_variables_to_constants_v2(m)
        frozen_func.graph.as_graph_def()
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


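# Loading sketch: parsing the frozen graph back, following the wrap_frozen_graph
# pattern YOLOv5 uses at inference time ('x:0'/'Identity:0' are the default node names):
#   import tensorflow as tf
#   gd = tf.Graph().as_graph_def()
#   with open('yolov5s.pb', 'rb') as fh:
#       gd.ParseFromString(fh.read())
#   wrapped = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])
#   ge = wrapped.graph.as_graph_element
#   frozen_func = wrapped.prune(ge('x:0'), ge('Identity:0'))

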
def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
    # YOLOv5 TensorFlow Lite export
    try:
        import tensorflow as tf

        from models.tf import representative_dataset_gen

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        batch_size, ch, *imgsz = list(im.shape)  # BCHW
        f = str(file).replace('.pt', '-fp16.tflite')

        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
        converter.target_spec.supported_types = [tf.float16]
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if int8:
            dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False)  # representative data
            converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
            converter.target_spec.supported_types = []
            converter.inference_input_type = tf.uint8  # or tf.int8
            converter.inference_output_type = tf.uint8  # or tf.int8
            converter.experimental_new_quantizer = False
            f = str(file).replace('.pt', '-int8.tflite')

        tflite_model = converter.convert()
        open(f, "wb").write(tflite_model)
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


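# Inference sketch: running the TFLite artifact with the built-in interpreter
# (`im_bhwc` is a hypothetical float32 BHWC numpy array):
#   import tensorflow as tf
#   interpreter = tf.lite.Interpreter(model_path='yolov5s-fp16.tflite')
#   interpreter.allocate_tensors()
#   interpreter.set_tensor(interpreter.get_input_details()[0]['index'], im_bhwc)
#   interpreter.invoke()
#   y = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])

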
def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
    # YOLOv5 TensorFlow.js export
    try:
        check_requirements(('tensorflowjs',))
        import re

        import tensorflowjs as tfjs

        LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
        f = str(file).replace('.pt', '_web_model')  # js dir
        f_pb = file.with_suffix('.pb')  # *.pb path
        f_json = f + '/model.json'  # *.json path

        cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
              f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
        subprocess.run(cmd, shell=True)

        json_txt = open(f_json).read()  # avoid shadowing the json module
        with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
            subst = re.sub(
                r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
                r'"Identity.?.?": {"name": "Identity.?.?"}, '
                r'"Identity.?.?": {"name": "Identity.?.?"}, '
                r'"Identity.?.?": {"name": "Identity.?.?"}}}',
                r'{"outputs": {"Identity": {"name": "Identity"}, '
                r'"Identity_1": {"name": "Identity_1"}, '
                r'"Identity_2": {"name": "Identity_2"}, '
                r'"Identity_3": {"name": "Identity_3"}}}',
                json_txt)
            j.write(subst)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    try:
        check_requirements(('tensorrt',))
        import tensorrt as trt

        opset = (12, 13)[trt.__version__[0] == '8']  # opset 13 with TensorRT 8.x, else 12; tested on TensorRT 7.x and 8.x
        export_onnx(model, im, file, opset, train, False, simplify)
        onnx = file.with_suffix('.onnx')
        assert onnx.exists(), f'failed to export ONNX file: {onnx}'

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        f = file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = workspace * 1 << 30

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx)):
            raise RuntimeError(f'failed to load ONNX file: {onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for inp in inputs:
            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        half &= builder.platform_has_fast_fp16
        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
        if half:
            config.set_flag(trt.BuilderFlag.FP16)
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            t.write(engine.serialize())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')


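# Loading sketch: deserializing the saved engine with the TensorRT runtime:
#   import tensorrt as trt
#   logger = trt.Logger(trt.Logger.INFO)
#   with open('yolov5s.engine', 'rb') as fh, trt.Runtime(logger) as runtime:
#       engine = runtime.deserialize_cuda_engine(fh.read())

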
@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
        weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        optimize=False,  # TorchScript: optimize for mobile
        int8=False,  # CoreML/TF INT8 quantization
        dynamic=False,  # ONNX/TF: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        verbose=False,  # TensorRT: verbose log
        workspace=4,  # TensorRT: workspace size (GB)
        nms=False,  # TF: add NMS to model
        agnostic_nms=False,  # TF: add agnostic NMS to model
        topk_per_class=100,  # TF.js NMS: topk per class to keep
        topk_all=100,  # TF.js NMS: topk for all classes to keep
        iou_thres=0.45,  # TF.js NMS: IoU threshold
        conf_thres=0.25  # TF.js NMS: confidence threshold
        ):
    t = time.time()
    include = [x.lower() for x in include]
    tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs'))  # TensorFlow exports
    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)

    # Checks
    imgsz *= 2 if len(imgsz) == 1 else 1  # expand
    opset = 12 if ('openvino' in include) else opset  # OpenVINO requires opset <= 12

    # Load PyTorch model
    device = select_device(device)
    assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
    model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 model
    nc, names = model.nc, model.names  # number of classes, class names

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

    # Update model
    if half:
        im, model = im.half(), model.half()  # to FP16
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, Detect):
            m.inplace = inplace
            m.onnx_dynamic = dynamic
            # m.forward = m.forward_export  # assign forward (optional)

    for _ in range(2):
        y = model(im)  # dry runs
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")

    # Exports
    if 'torchscript' in include:
        export_torchscript(model, im, file, optimize)
    if ('onnx' in include) or ('openvino' in include):  # OpenVINO requires ONNX
        export_onnx(model, im, file, opset, train, dynamic, simplify)
    if 'engine' in include:
        export_engine(model, im, file, train, half, simplify, workspace, verbose)
    if 'coreml' in include:
        export_coreml(model, im, file)
    if 'openvino' in include:
        export_openvino(model, im, file)

    # TensorFlow Exports
    if any(tf_exports):
        pb, tflite, tfjs = tf_exports[1:]
        assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
        model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
                                   agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
                                   conf_thres=conf_thres, iou_thres=iou_thres)  # keras model
        if pb or tfjs:  # pb prerequisite to tfjs
            export_pb(model, im, file)
        if tflite:
            export_tflite(model, im, file, int8=int8, data=data, ncalib=100)
        if tfjs:
            export_tfjs(model, im, file)

    # Finish
    LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                f'\nVisualize with https://netron.app')


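# Programmatic usage sketch, equivalent to the CLI defaults in parse_opt() below:
#   run(weights='runs/train/yolov5m_finetune_dotav1.5/weights/best.pt',
#       imgsz=(1024, 1024), include=('torchscript', 'onnx'))

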
def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default=ROOT / 'data/dotav15_poly.yaml', help='dataset.yaml path')
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/yolov5m_finetune_dotav1.5/weights/best.pt', help='model.pt path(s)')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[1024, 1024], help='image (h, w)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
    parser.add_argument('--train', action='store_true', help='model.train() mode')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
    parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
    parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
    parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
    parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
    parser.add_argument('--include', nargs='+',
                        default=['torchscript', 'onnx'],
                        help='available formats are (torchscript, onnx, engine, coreml, openvino, saved_model, pb, tflite, tfjs)')
    opt = parser.parse_args()
    print_args(FILE.stem, opt)
    return opt


def main(opt):
    for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
        run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)