Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/java-wav-file-example/src/main/java/org/example/SileroVadOnnxModel.java
1171 views
1
package org.example;
2
3
import ai.onnxruntime.OnnxTensor;
4
import ai.onnxruntime.OrtEnvironment;
5
import ai.onnxruntime.OrtException;
6
import ai.onnxruntime.OrtSession;
7
import java.util.Arrays;
8
import java.util.HashMap;
9
import java.util.List;
10
import java.util.Map;
11
12
public class SileroVadOnnxModel {
13
// Define private variable OrtSession
14
private final OrtSession session;
15
private float[][][] state;
16
private float[][] context;
17
// Define the last sample rate
18
private int lastSr = 0;
19
// Define the last batch size
20
private int lastBatchSize = 0;
21
// Define a list of supported sample rates
22
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
23
24
// Constructor
25
public SileroVadOnnxModel(String modelPath) throws OrtException {
26
// Get the ONNX runtime environment
27
OrtEnvironment env = OrtEnvironment.getEnvironment();
28
// Create an ONNX session options object
29
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
30
// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
31
opts.setInterOpNumThreads(1);
32
// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
33
opts.setIntraOpNumThreads(1);
34
// Add a CPU device, setting to false disables CPU execution optimization
35
opts.addCPU(true);
36
// Create an ONNX session using the environment, model path, and options
37
session = env.createSession(modelPath, opts);
38
// Reset states
39
resetStates();
40
}
41
42
/**
43
* Reset states
44
*/
45
void resetStates() {
46
state = new float[2][1][128];
47
context = new float[0][];
48
lastSr = 0;
49
lastBatchSize = 0;
50
}
51
52
public void close() throws OrtException {
53
session.close();
54
}
55
56
/**
57
* Define inner class ValidationResult
58
*/
59
public static class ValidationResult {
60
public final float[][] x;
61
public final int sr;
62
63
// Constructor
64
public ValidationResult(float[][] x, int sr) {
65
this.x = x;
66
this.sr = sr;
67
}
68
}
69
70
/**
71
* Function to validate input data
72
*/
73
private ValidationResult validateInput(float[][] x, int sr) {
74
// Process the input data with dimension 1
75
if (x.length == 1) {
76
x = new float[][]{x[0]};
77
}
78
// Throw an exception when the input data dimension is greater than 2
79
if (x.length > 2) {
80
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
81
}
82
83
// Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
84
if (sr != 16000 && (sr % 16000 == 0)) {
85
int step = sr / 16000;
86
float[][] reducedX = new float[x.length][];
87
88
for (int i = 0; i < x.length; i++) {
89
float[] current = x[i];
90
float[] newArr = new float[(current.length + step - 1) / step];
91
92
for (int j = 0, index = 0; j < current.length; j += step, index++) {
93
newArr[index] = current[j];
94
}
95
96
reducedX[i] = newArr;
97
}
98
99
x = reducedX;
100
sr = 16000;
101
}
102
103
// If the sample rate is not in the list of supported sample rates, throw an exception
104
if (!SAMPLE_RATES.contains(sr)) {
105
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
106
}
107
108
// If the input audio block is too short, throw an exception
109
if (((float) sr) / x[0].length > 31.25) {
110
throw new IllegalArgumentException("Input audio is too short");
111
}
112
113
// Return the validated result
114
return new ValidationResult(x, sr);
115
}
116
117
private static float[][] concatenate(float[][] a, float[][] b) {
118
if (a.length != b.length) {
119
throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
120
}
121
122
int rows = a.length;
123
int colsA = a[0].length;
124
int colsB = b[0].length;
125
float[][] result = new float[rows][colsA + colsB];
126
127
for (int i = 0; i < rows; i++) {
128
System.arraycopy(a[i], 0, result[i], 0, colsA);
129
System.arraycopy(b[i], 0, result[i], colsA, colsB);
130
}
131
132
return result;
133
}
134
135
private static float[][] getLastColumns(float[][] array, int contextSize) {
136
int rows = array.length;
137
int cols = array[0].length;
138
139
if (contextSize > cols) {
140
throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
141
}
142
143
float[][] result = new float[rows][contextSize];
144
145
for (int i = 0; i < rows; i++) {
146
System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
147
}
148
149
return result;
150
}
151
152
/**
153
* Method to call the ONNX model
154
*/
155
public float[] call(float[][] x, int sr) throws OrtException {
156
ValidationResult result = validateInput(x, sr);
157
x = result.x;
158
sr = result.sr;
159
int numberSamples = 256;
160
if (sr == 16000) {
161
numberSamples = 512;
162
}
163
164
if (x[0].length != numberSamples) {
165
throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
166
}
167
168
int batchSize = x.length;
169
170
int contextSize = 32;
171
if (sr == 16000) {
172
contextSize = 64;
173
}
174
175
if (lastBatchSize == 0) {
176
resetStates();
177
}
178
if (lastSr != 0 && lastSr != sr) {
179
resetStates();
180
}
181
if (lastBatchSize != 0 && lastBatchSize != batchSize) {
182
resetStates();
183
}
184
185
if (context.length == 0) {
186
context = new float[batchSize][contextSize];
187
}
188
189
x = concatenate(context, x);
190
191
OrtEnvironment env = OrtEnvironment.getEnvironment();
192
193
OnnxTensor inputTensor = null;
194
OnnxTensor stateTensor = null;
195
OnnxTensor srTensor = null;
196
OrtSession.Result ortOutputs = null;
197
198
try {
199
// Create input tensors
200
inputTensor = OnnxTensor.createTensor(env, x);
201
stateTensor = OnnxTensor.createTensor(env, state);
202
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
203
204
Map<String, OnnxTensor> inputs = new HashMap<>();
205
inputs.put("input", inputTensor);
206
inputs.put("sr", srTensor);
207
inputs.put("state", stateTensor);
208
209
// Call the ONNX model for calculation
210
ortOutputs = session.run(inputs);
211
// Get the output results
212
float[][] output = (float[][]) ortOutputs.get(0).getValue();
213
state = (float[][][]) ortOutputs.get(1).getValue();
214
215
context = getLastColumns(x, contextSize);
216
lastSr = sr;
217
lastBatchSize = batchSize;
218
return output[0];
219
} finally {
220
if (inputTensor != null) {
221
inputTensor.close();
222
}
223
if (stateTensor != null) {
224
stateTensor.close();
225
}
226
if (srTensor != null) {
227
srTensor.close();
228
}
229
if (ortOutputs != null) {
230
ortOutputs.close();
231
}
232
}
233
}
234
}
235
236