Path: blob/master/DNN-OpenCV-Classification-with-Java/DnnOpenCV.java
3119 views
import org.opencv.core.Core;1import org.opencv.core.Mat;2import org.opencv.core.Point;3import org.opencv.core.Rect;4import org.opencv.core.Scalar;5import org.opencv.core.Size;6import org.opencv.dnn.Net;7import org.opencv.dnn.Dnn;8import org.opencv.imgproc.Imgproc;9import org.opencv.highgui.HighGui;10import org.opencv.imgcodecs.Imgcodecs;1112import java.io.IOException;13import java.util.ArrayList;14import java.util.stream.Collectors;15import java.util.stream.Stream;16import java.nio.file.Files;17import java.nio.file.Paths;1819import org.opencv.core.CvType;202122public class DnnOpenCV {23private static final int TARGET_IMG_WIDTH = 224;24private static final int TARGET_IMG_HEIGHT = 224;2526private static final double SCALE_FACTOR = 1 / 255.0;2728private static final String IMAGENET_CLASSES = "imagenet_classes.txt";29private static final String MODEL_PATH = "models/pytorch_mobilenet.onnx";3031private static final Scalar MEAN = new Scalar(0.485, 0.456, 0.406);32private static final Scalar STD = new Scalar(0.229, 0.224, 0.225);3334private static Mat imageRead;3536public static ArrayList<String> getImgLabels(String imgLabelsFilePath) throws IOException {37ArrayList<String> imgLabels;38try (Stream<String> lines = Files.lines(Paths.get(imgLabelsFilePath))) {39imgLabels = lines.collect(Collectors.toCollection(ArrayList::new));40}41return imgLabels;42}4344public static Mat centerCrop(Mat inputImage) {45int y1 = Math.round((inputImage.rows() - TARGET_IMG_HEIGHT) / 2);46int y2 = Math.round(y1 + TARGET_IMG_HEIGHT);47int x1 = Math.round((inputImage.cols() - TARGET_IMG_WIDTH) / 2);48int x2 = Math.round(x1 + TARGET_IMG_WIDTH);4950Rect centerRect = new Rect(x1, y1, (x2 - x1), (y2 - y1));51Mat croppedImage = new Mat(inputImage, centerRect);5253return croppedImage;54}5556public static Mat getPreprocessedImage(String imagePath) {57// get the image from the internal resource folder58imageRead = Imgcodecs.imread(imagePath);5960// this object will store the preprocessed image61Mat image = new Mat();6263// resize input image64Imgproc.resize(imageRead, image, new Size(256, 256));6566// create empty Mat images for float conversions67Mat imgFloat = new Mat(image.rows(), image.cols(), CvType.CV_32FC3);6869// convert input image to float type70image.convertTo(imgFloat, CvType.CV_32FC3, SCALE_FACTOR);7172// crop input image73imgFloat = centerCrop(imgFloat);7475// prepare DNN input76Mat blob = Dnn.blobFromImage(77imgFloat,781.0, /* default scalefactor */79new Size(TARGET_IMG_WIDTH, TARGET_IMG_HEIGHT), /* target size */80MEAN, /* mean */81true, /* swapRB */82false /* crop */83);8485// divide on std86Core.divide(blob, STD, blob);8788return blob;89}9091public static String getPredictedClass(Mat classificationResult) {92ArrayList<String> imgLabels = new ArrayList<String>();93try {94imgLabels = getImgLabels(IMAGENET_CLASSES);95} catch (IOException ex) {96System.out.printf("Could not read %s file:%n", IMAGENET_CLASSES);97ex.printStackTrace();98}99if (imgLabels.isEmpty()) {100return "";101}102// obtain max prediction result103Core.MinMaxLocResult mm = Core.minMaxLoc(classificationResult);104double maxValIndex = mm.maxLoc.x;105return imgLabels.get((int) maxValIndex);106}107108public static void main(String[] args) {109String imageLocation = "images/coffee.jpg";110111// load the OpenCV native library112System.loadLibrary(Core.NATIVE_LIBRARY_NAME);113114// read and process the input image115Mat inputBlob = DnnOpenCV.getPreprocessedImage(imageLocation);116117// read generated ONNX model into org.opencv.dnn.Net object118Net dnnNet = Dnn.readNetFromONNX(DnnOpenCV.MODEL_PATH);119System.out.println("DNN from ONNX was successfully loaded!");120121// set OpenCV model input122dnnNet.setInput(inputBlob);123124// provide inference125Mat classification = dnnNet.forward();126127// decode classification results128String label = DnnOpenCV.getPredictedClass(classification);129System.out.println("Predicted Class: " + label);130131// displaying the photo and putting the text on it132Point pos = new Point (50, 50);133Scalar colour = new Scalar(255, 255, 255);134Imgproc.putText(imageRead, "Predicted class is: " +label, pos, Imgproc.FONT_HERSHEY_SIMPLEX, 1.0, colour, 2);135HighGui.imshow("Input Image", imageRead);136if (HighGui.waitKey(0) == 27){137System.exit(0);138}139}140}141142143