# Path: blob/master/DNN-OpenCV-Classification-Android/Mobilenetv2ToOnnx.py
# 3118 views
import argparse
import os

import cv2
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
)
from torchvision import models


def compare_pytorch_onnx(
    original_model_preds, onnx_model_path, input_image,
):
    """Run the exported ONNX model and check it reproduces the PyTorch output.

    Args:
        original_model_preds: NumPy array of PyTorch model outputs computed
            for ``input_image``.
        onnx_model_path: Path to the exported ``.onnx`` file.
        input_image: Preprocessed NumPy input batch fed to the ONNX session.

    Raises:
        AssertionError: If the ONNX output differs from ``original_model_preds``
            in shape, or in any value at 3 decimal places.
    """
    # get onnx result
    session = onnxruntime.InferenceSession(onnx_model_path)
    input_name = session.get_inputs()[0].name
    # ``None`` as output_names requests every model output.
    onnx_result = session.run(None, {input_name: input_image})
    # ``run`` returns a list with one array per output; drop that list axis.
    onnx_result = np.squeeze(onnx_result, axis=0)

    print("Checking PyTorch model and converted ONNX model outputs ... ")
    # Compare whole arrays rather than zipping rows: zip would silently stop
    # at the shorter input, while this also fails on a shape mismatch.
    np.testing.assert_almost_equal(
        original_model_preds, onnx_result, decimal=3,
    )
    print("PyTorch and ONNX output values are equal! \n")


def get_onnx_model(
    original_model, input_image, model_path="models", model_name="pytorch_mobilenet",
):
    """Export ``original_model`` to ONNX and return the written file path.

    Args:
        original_model: PyTorch model to export.
        input_image: Example NumPy input batch used by ``torch.onnx.export``
            to run the model during export.
        model_path: Directory for the ONNX file; created if missing.
        model_name: File name without the ``.onnx`` extension.

    Returns:
        Path of the generated ``.onnx`` file.
    """
    # create model root dir
    os.makedirs(model_path, exist_ok=True)

    onnx_model_path = os.path.join(model_path, model_name + ".onnx")

    torch.onnx.export(
        original_model, torch.Tensor(input_image), onnx_model_path, verbose=True,
    )
    print("ONNX model was successfully generated: {} \n".format(onnx_model_path))

    return onnx_model_path


def get_preprocessed_image(image_name):
    """Load an image and turn it into a 1x3x224x224 normalized input batch.

    Args:
        image_name: Path of the image file to read.

    Returns:
        NumPy array of shape (1, 3, 224, 224), channels in RGB order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_name``.
    """
    # read image (OpenCV loads it in BGR channel order)
    original_image = cv2.imread(image_name)
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    # Fail here with a clear message instead of deep inside cvtColor.
    if original_image is None:
        raise FileNotFoundError("Could not read image: {}".format(image_name))

    # convert original image to RGB format
    image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # transform input image:
    # 1. resize the image
    # 2. crop the image
    # 3. normalize: subtract mean and divide by standard deviation
    transform = Compose(
        [
            Resize(height=256, width=256),
            CenterCrop(224, 224),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ],
    )
    image = transform(image=image)["image"]

    # change the order of channels: HWC -> CHW
    image = image.transpose(2, 0, 1)
    # transpose only permutes strides; hand a C-contiguous buffer to the
    # consumers (torch and ONNX Runtime) instead of a strided view.
    return np.ascontiguousarray(np.expand_dims(image, axis=0))


def get_predicted_class(pytorch_preds):
    """Print the label of the highest-scoring class for the first sample.

    Labels are read one per line from ``imagenet_classes.txt`` in the
    current working directory; line ``i`` is taken as the name of class
    index ``i``.

    Args:
        pytorch_preds: 2-D array of class scores (samples x classes).
    """
    # read ImageNet class id to name mapping
    with open("imagenet_classes.txt", encoding="utf-8") as f:
        labels = [line.strip() for line in f]

    # find the class with the maximum score
    pytorch_class_idx = np.argmax(pytorch_preds, axis=1)
    predicted_pytorch_label = labels[pytorch_class_idx[0]]

    # print top predicted class
    print("Predicted class by PyTorch model: ", predicted_pytorch_label)


def get_execution_arguments():
    """Parse command-line arguments (``--input_image``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_image",
        type=str,
        help="Define the full input image path, including its name",
        default="test_img_cup.jpg",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # get the test case parameters
    args = get_execution_arguments()

    # read and process the input image
    image = get_preprocessed_image(image_name=args.input_image)

    # obtain original model
    # NOTE(review): newer torchvision releases deprecate ``pretrained=True``
    # in favour of the ``weights=`` argument; kept for older versions.
    pytorch_model = models.mobilenet_v2(pretrained=True)

    # provide inference of the original PyTorch model (no autograd needed)
    pytorch_model.eval()
    with torch.no_grad():
        pytorch_predictions = pytorch_model(torch.Tensor(image)).numpy()

    # export the PyTorch model to ONNX
    onnx_model_path = get_onnx_model(original_model=pytorch_model, input_image=image)

    # check if conversion succeeded
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    # check onnx model output
    compare_pytorch_onnx(
        pytorch_predictions, onnx_model_path, image,
    )

    # decode classification results
    get_predicted_class(pytorch_preds=pytorch_predictions)