GitHub Repository: hukaixuan19970627/yolov5_obb
Path: blob/master/detect.py
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.

Usage:
    $ python path/to/detect.py --weights yolov5s.pt --source 0  # webcam
                                                             img.jpg  # image
                                                             vid.mp4  # video
                                                             path/  # directory
                                                             path/*.jpg  # glob
                                                             'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                             'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
                           increment_path, non_max_suppression, non_max_suppression_obb, print_args,
                           scale_coords, scale_polys, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
from utils.rboxs_utils import poly2rbox, rbox2poly
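
# poly2rbox / rbox2poly convert between the two oriented-box representations used
# below; formats inferred from the shape comments in run():
#   rbox: [x_ctr, y_ctr, longside, shortside, theta], theta in [-pi/2, pi/2)
#   poly: [x1, y1, x2, y2, x3, y3, x4, y4], the four corner points in pixel coords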


@torch.no_grad()
def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --classes 0, or --classes 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        ):
    source = str(source)
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
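    # 'webcam' is shorthand for any streaming source: a numeric device id, a .txt
    # file listing multiple sources, or a direct (non-file) URL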
    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn)
    stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Half
    half &= (pt or jit or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
    if pt or jit:
        model.model.half() if half else model.model.float()

    # Dataloader
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
        bs = len(dataset)  # batch_size
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
        bs = 1  # batch_size
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    model.warmup(imgsz=(1, 3, *imgsz), half=half)  # warmup
    dt, seen = [0.0, 0.0, 0.0], 0
    for path, im, im0s, vid_cap, s in dataset:
        t1 = time_sync()
        im = torch.from_numpy(im).to(device)
        im = im.half() if half else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        t2 = time_sync()
        dt[0] += t2 - t1

        # Inference
        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
        pred = model(im, augment=augment, visualize=visualize)
        t3 = time_sync()
        dt[1] += t3 - t2

        # NMS
        # pred: list*(n, [xylsθ, conf, cls]) θ ∈ [-pi/2, pi/2)
        pred = non_max_suppression_obb(pred, conf_thres, iou_thres, classes, agnostic_nms, multi_label=True, max_det=max_det)
        dt[2] += time_sync() - t3
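        # Unlike plain non_max_suppression, the _obb variant must compare *rotated*
        # boxes, so overlap is presumably computed with a rotated-box IoU (see
        # utils/general.py); each returned det is (n, 7): [x y l s theta conf cls]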

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            pred_poly = rbox2poly(det[:, :5])  # (n, [x1 y1 x2 y2 x3 y3 x4 y4])
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # im.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
            s += '%gx%g ' % im.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale polys from img_size to im0 size
                # det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
                pred_poly = scale_polys(im.shape[2:], pred_poly, im0.shape)
                det = torch.cat((pred_poly, det[:, -2:]), dim=1)  # (n, [poly conf cls])

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *poly, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        # xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        poly = torch.stack(poly).tolist()  # starred unpacking yields a list of 0-dim tensors; stack before tolist()
                        line = (cls, *poly, conf) if save_conf else (cls, *poly)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
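                        # Saved label line: cls x1 y1 x2 y2 x3 y3 x4 y4 [conf], in im0
                        # pixel coordinates; gn (whwh gain) is computed above but never
                        # applied, so the polys are written unnormalized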

                    if save_img or save_crop or view_img:  # Add poly to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        # annotator.box_label(xyxy, label, color=colors(c, True))
                        annotator.poly_label(poly, label, color=colors(c, True))
                        if save_crop:  # Yolov5-obb doesn't support it yet
                            # save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
                            pass
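                            # no-op: save_one_box (commented out above) takes an
                            # axis-aligned xyxy box, not a 4-point poly, which is
                            # presumably why cropping is unsupported here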

            # Print time (inference-only)
            LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')

            # Stream results
            im0 = annotator.result()
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer[i].write(im0)

    # Print results
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/yolov5n_DroneVehicle/weights/best.pt', help='model path(s)')
    parser.add_argument('--source', type=str, default='/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DroneVehicle/val/raw/images/', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[840], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.2, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='3', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=2, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
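    # a single --imgsz value N is expanded to [N, N] (inference height, width)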
    print_args(FILE.stem, opt)
    return opt


def main(opt):
    check_requirements(exclude=('tensorboard', 'thop'))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
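
# Example invocation (hypothetical source path; every flag is defined in parse_opt above):
#   $ python detect.py --weights runs/train/yolov5n_DroneVehicle/weights/best.pt \
#         --source path/to/images --imgsz 840 --device 0 --save-txt --save-conf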