Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/DNN-OpenCV-Classification-with-Java/DnnOpenCV.java
3119 views
1
import org.opencv.core.Core;
2
import org.opencv.core.Mat;
3
import org.opencv.core.Point;
4
import org.opencv.core.Rect;
5
import org.opencv.core.Scalar;
6
import org.opencv.core.Size;
7
import org.opencv.dnn.Net;
8
import org.opencv.dnn.Dnn;
9
import org.opencv.imgproc.Imgproc;
10
import org.opencv.highgui.HighGui;
11
import org.opencv.imgcodecs.Imgcodecs;
12
13
import java.io.IOException;
14
import java.util.ArrayList;
15
import java.util.stream.Collectors;
16
import java.util.stream.Stream;
17
import java.nio.file.Files;
18
import java.nio.file.Paths;
19
20
import org.opencv.core.CvType;
21
22
23
public class DnnOpenCV {
24
private static final int TARGET_IMG_WIDTH = 224;
25
private static final int TARGET_IMG_HEIGHT = 224;
26
27
private static final double SCALE_FACTOR = 1 / 255.0;
28
29
private static final String IMAGENET_CLASSES = "imagenet_classes.txt";
30
private static final String MODEL_PATH = "models/pytorch_mobilenet.onnx";
31
32
private static final Scalar MEAN = new Scalar(0.485, 0.456, 0.406);
33
private static final Scalar STD = new Scalar(0.229, 0.224, 0.225);
34
35
private static Mat imageRead;
36
37
public static ArrayList<String> getImgLabels(String imgLabelsFilePath) throws IOException {
38
ArrayList<String> imgLabels;
39
try (Stream<String> lines = Files.lines(Paths.get(imgLabelsFilePath))) {
40
imgLabels = lines.collect(Collectors.toCollection(ArrayList::new));
41
}
42
return imgLabels;
43
}
44
45
public static Mat centerCrop(Mat inputImage) {
46
int y1 = Math.round((inputImage.rows() - TARGET_IMG_HEIGHT) / 2);
47
int y2 = Math.round(y1 + TARGET_IMG_HEIGHT);
48
int x1 = Math.round((inputImage.cols() - TARGET_IMG_WIDTH) / 2);
49
int x2 = Math.round(x1 + TARGET_IMG_WIDTH);
50
51
Rect centerRect = new Rect(x1, y1, (x2 - x1), (y2 - y1));
52
Mat croppedImage = new Mat(inputImage, centerRect);
53
54
return croppedImage;
55
}
56
57
public static Mat getPreprocessedImage(String imagePath) {
58
// get the image from the internal resource folder
59
imageRead = Imgcodecs.imread(imagePath);
60
61
// this object will store the preprocessed image
62
Mat image = new Mat();
63
64
// resize input image
65
Imgproc.resize(imageRead, image, new Size(256, 256));
66
67
// create empty Mat images for float conversions
68
Mat imgFloat = new Mat(image.rows(), image.cols(), CvType.CV_32FC3);
69
70
// convert input image to float type
71
image.convertTo(imgFloat, CvType.CV_32FC3, SCALE_FACTOR);
72
73
// crop input image
74
imgFloat = centerCrop(imgFloat);
75
76
// prepare DNN input
77
Mat blob = Dnn.blobFromImage(
78
imgFloat,
79
1.0, /* default scalefactor */
80
new Size(TARGET_IMG_WIDTH, TARGET_IMG_HEIGHT), /* target size */
81
MEAN, /* mean */
82
true, /* swapRB */
83
false /* crop */
84
);
85
86
// divide on std
87
Core.divide(blob, STD, blob);
88
89
return blob;
90
}
91
92
public static String getPredictedClass(Mat classificationResult) {
93
ArrayList<String> imgLabels = new ArrayList<String>();
94
try {
95
imgLabels = getImgLabels(IMAGENET_CLASSES);
96
} catch (IOException ex) {
97
System.out.printf("Could not read %s file:%n", IMAGENET_CLASSES);
98
ex.printStackTrace();
99
}
100
if (imgLabels.isEmpty()) {
101
return "";
102
}
103
// obtain max prediction result
104
Core.MinMaxLocResult mm = Core.minMaxLoc(classificationResult);
105
double maxValIndex = mm.maxLoc.x;
106
return imgLabels.get((int) maxValIndex);
107
}
108
109
public static void main(String[] args) {
110
String imageLocation = "images/coffee.jpg";
111
112
// load the OpenCV native library
113
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
114
115
// read and process the input image
116
Mat inputBlob = DnnOpenCV.getPreprocessedImage(imageLocation);
117
118
// read generated ONNX model into org.opencv.dnn.Net object
119
Net dnnNet = Dnn.readNetFromONNX(DnnOpenCV.MODEL_PATH);
120
System.out.println("DNN from ONNX was successfully loaded!");
121
122
// set OpenCV model input
123
dnnNet.setInput(inputBlob);
124
125
// provide inference
126
Mat classification = dnnNet.forward();
127
128
// decode classification results
129
String label = DnnOpenCV.getPredictedClass(classification);
130
System.out.println("Predicted Class: " + label);
131
132
// displaying the photo and putting the text on it
133
Point pos = new Point (50, 50);
134
Scalar colour = new Scalar(255, 255, 255);
135
Imgproc.putText(imageRead, "Predicted class is: " +label, pos, Imgproc.FONT_HERSHEY_SIMPLEX, 1.0, colour, 2);
136
HighGui.imshow("Input Image", imageRead);
137
if (HighGui.waitKey(0) == 27){
138
System.exit(0);
139
}
140
}
141
}
142
143