GitHub Repository: hackassin/learnopencv
Path: blob/master/DNN-OpenCV-Classification-Android/Mobilenetv2ToOnnx.py
import argparse
import os

import cv2
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
)
from torchvision import models


def compare_pytorch_onnx(
    original_model_preds, onnx_model_path, input_image,
):
    # get onnx result
    session = onnxruntime.InferenceSession(onnx_model_path)
    input_name = session.get_inputs()[0].name
    onnx_result = session.run([], {input_name: input_image})
    onnx_result = np.squeeze(onnx_result, axis=0)

    print("Checking PyTorch model and converted ONNX model outputs ... ")
    for test_onnx_result, gold_result in zip(onnx_result, original_model_preds):
        np.testing.assert_almost_equal(
            gold_result, test_onnx_result, decimal=3,
        )
    print("PyTorch and ONNX output values are equal! \n")


def get_onnx_model(
    original_model, input_image, model_path="models", model_name="pytorch_mobilenet",
):
    # create model root dir
    os.makedirs(model_path, exist_ok=True)

    model_name = os.path.join(model_path, model_name + ".onnx")

    torch.onnx.export(
        original_model, torch.Tensor(input_image), model_name, verbose=True,
    )
    print("ONNX model was successfully generated: {} \n".format(model_name))

    return model_name


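# NOTE: illustrative sketch, not part of the original script. torch.onnx.export
# also accepts explicit tensor names and an opset version, which can make the
# exported graph easier to consume from other runtimes such as OpenCV DNN.
# The helper name, the "input"/"output" names and opset 11 are assumptions,
# not values taken from the tutorial.
def get_onnx_model_named(
    original_model, input_image, model_path="models", model_name="pytorch_mobilenet",
):
    os.makedirs(model_path, exist_ok=True)
    onnx_file = os.path.join(model_path, model_name + ".onnx")
    torch.onnx.export(
        original_model,
        torch.Tensor(input_image),
        onnx_file,
        input_names=["input"],    # assumed input tensor name
        output_names=["output"],  # assumed output tensor name
        opset_version=11,         # assumed opset; pick one your runtime supports
    )
    return onnx_file

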
def get_preprocessed_image(image_name):
    # read image
    original_image = cv2.imread(image_name)

    # convert original image to RGB format
    image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # transform input image:
    # 1. resize the image
    # 2. crop the image
    # 3. normalize: subtract mean and divide by standard deviation
    transform = Compose(
        [
            Resize(height=256, width=256),
            CenterCrop(224, 224),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ],
    )
    image = transform(image=image)["image"]

    # change the order of channels
    image = image.transpose(2, 0, 1)
    return np.expand_dims(image, axis=0)


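# NOTE: illustrative sketch, not part of the original script. The albumentations
# pipeline above implements the usual ImageNet preprocessing for MobileNetV2;
# roughly the same blob can be produced with torchvision.transforms (minor
# interpolation differences aside). The helper name is hypothetical.
def get_preprocessed_image_torchvision(image_name):
    from torchvision import transforms  # local import to keep the sketch self-contained

    rgb = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)
    preprocess = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),  # converts to CHW float in [0, 1]
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # add the batch dimension to match get_preprocessed_image()
    return preprocess(rgb).unsqueeze(0).numpy()

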
def get_predicted_class(pytorch_preds):
    # read ImageNet class id to name mapping
    with open("imagenet_classes.txt") as f:
        labels = [line.strip() for line in f.readlines()]

    # find the class with the maximum score
    pytorch_class_idx = np.argmax(pytorch_preds, axis=1)
    predicted_pytorch_label = labels[pytorch_class_idx[0]]

    # print top predicted class
    print("Predicted class by PyTorch model: ", predicted_pytorch_label)


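# NOTE: illustrative sketch, not part of the original script. The model returns
# raw logits; if a confidence value is needed alongside the predicted label, a
# softmax over the logits gives the probability of the top class. The helper
# name is hypothetical.
def get_top_probability(pytorch_preds):
    logits = pytorch_preds[0]
    # subtract the max logit for numerical stability before exponentiating
    exp_scores = np.exp(logits - np.max(logits))
    probabilities = exp_scores / exp_scores.sum()
    return float(np.max(probabilities))

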
def get_execution_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_image",
        type=str,
        help="Path to the input image, including the file name",
        default="test_img_cup.jpg",
    )
    return parser.parse_args()


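# NOTE: illustrative sketch, not part of the original script. The exported ONNX
# file is intended to be consumed by OpenCV's DNN module (as in the Android part
# of this tutorial); a minimal desktop check could look like this. The function
# name is hypothetical, and `input_image` is assumed to be the NCHW float32 blob
# produced by get_preprocessed_image().
def classify_with_opencv_dnn(onnx_model_path, input_image):
    net = cv2.dnn.readNetFromONNX(onnx_model_path)
    net.setInput(input_image)
    scores = net.forward()
    # return the index of the top-scoring ImageNet class
    return int(np.argmax(scores, axis=1)[0])

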
if __name__ == "__main__":
    # parse the test case parameters
    args = get_execution_arguments()

    # read and preprocess the input image
    image = get_preprocessed_image(image_name=args.input_image)

    # obtain the original pretrained PyTorch model
    pytorch_model = models.mobilenet_v2(pretrained=True)

    # run inference with the original PyTorch model
    pytorch_model.eval()
    pytorch_predictions = pytorch_model(torch.Tensor(image)).detach().numpy()

    # export the PyTorch model to ONNX format
    onnx_model_path = get_onnx_model(original_model=pytorch_model, input_image=image)

    # check that the conversion succeeded
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    # compare the ONNX model output with the PyTorch predictions
    compare_pytorch_onnx(
        pytorch_predictions, onnx_model_path, image,
    )

    # decode classification results
    get_predicted_class(pytorch_preds=pytorch_predictions)