Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
1898 views
1
package org.example;
2
3
import ai.onnxruntime.OnnxTensor;
4
import ai.onnxruntime.OrtEnvironment;
5
import ai.onnxruntime.OrtException;
6
import ai.onnxruntime.OrtSession;
7
import java.util.Arrays;
8
import java.util.HashMap;
9
import java.util.List;
10
import java.util.Map;
11
12
/**
13
* Silero VAD ONNX Model Wrapper
14
*
15
* @author VvvvvGH
16
*/
17
public class SlieroVadOnnxModel {
18
// ONNX runtime session
19
private final OrtSession session;
20
// Model state - dimensions: [2, batch_size, 128]
21
private float[][][] state;
22
// Context - stores the tail of the previous audio chunk
23
private float[][] context;
24
// Last sample rate
25
private int lastSr = 0;
26
// Last batch size
27
private int lastBatchSize = 0;
28
// Supported sample rates
29
private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
30
31
// Constructor
32
public SlieroVadOnnxModel(String modelPath) throws OrtException {
33
// Get the ONNX runtime environment
34
OrtEnvironment env = OrtEnvironment.getEnvironment();
35
// Create ONNX session options
36
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
37
// Set InterOp thread count to 1 (for parallel processing of different graph operations)
38
opts.setInterOpNumThreads(1);
39
// Set IntraOp thread count to 1 (for parallel processing within a single operation)
40
opts.setIntraOpNumThreads(1);
41
// Enable CPU execution optimization
42
opts.addCPU(true);
43
// Create ONNX session with the environment, model path, and options
44
session = env.createSession(modelPath, opts);
45
// Reset states
46
resetStates();
47
}
48
49
/**
50
* Reset states with default batch size
51
*/
52
void resetStates() {
53
resetStates(1);
54
}
55
56
/**
57
* Reset states with specific batch size
58
*
59
* @param batchSize Batch size for state initialization
60
*/
61
void resetStates(int batchSize) {
62
state = new float[2][batchSize][128];
63
context = new float[0][]; // Empty context
64
lastSr = 0;
65
lastBatchSize = 0;
66
}
67
68
public void close() throws OrtException {
69
session.close();
70
}
71
72
/**
73
* Inner class for validation result
74
*/
75
public static class ValidationResult {
76
public final float[][] x;
77
public final int sr;
78
79
public ValidationResult(float[][] x, int sr) {
80
this.x = x;
81
this.sr = sr;
82
}
83
}
84
85
/**
86
* Validate input data
87
*
88
* @param x Audio data array
89
* @param sr Sample rate
90
* @return Validated input data and sample rate
91
*/
92
private ValidationResult validateInput(float[][] x, int sr) {
93
// Ensure input is at least 2D
94
if (x.length == 1) {
95
x = new float[][]{x[0]};
96
}
97
// Check if input dimension is valid
98
if (x.length > 2) {
99
throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
100
}
101
102
// Downsample if sample rate is a multiple of 16000
103
if (sr != 16000 && (sr % 16000 == 0)) {
104
int step = sr / 16000;
105
float[][] reducedX = new float[x.length][];
106
107
for (int i = 0; i < x.length; i++) {
108
float[] current = x[i];
109
float[] newArr = new float[(current.length + step - 1) / step];
110
111
for (int j = 0, index = 0; j < current.length; j += step, index++) {
112
newArr[index] = current[j];
113
}
114
115
reducedX[i] = newArr;
116
}
117
118
x = reducedX;
119
sr = 16000;
120
}
121
122
// Validate sample rate
123
if (!SAMPLE_RATES.contains(sr)) {
124
throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
125
}
126
127
// Check if audio chunk is too short
128
if (((float) sr) / x[0].length > 31.25) {
129
throw new IllegalArgumentException("Input audio is too short");
130
}
131
132
return new ValidationResult(x, sr);
133
}
134
135
/**
136
* Call the ONNX model for inference
137
*
138
* @param x Audio data array
139
* @param sr Sample rate
140
* @return Speech probability output
141
* @throws OrtException If ONNX runtime error occurs
142
*/
143
public float[] call(float[][] x, int sr) throws OrtException {
144
ValidationResult result = validateInput(x, sr);
145
x = result.x;
146
sr = result.sr;
147
148
int batchSize = x.length;
149
int numSamples = sr == 16000 ? 512 : 256;
150
int contextSize = sr == 16000 ? 64 : 32;
151
152
// Reset states only when sample rate or batch size changes
153
if (lastSr != 0 && lastSr != sr) {
154
resetStates(batchSize);
155
} else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
156
resetStates(batchSize);
157
} else if (lastBatchSize == 0) {
158
// First call - state is already initialized, just set batch size
159
lastBatchSize = batchSize;
160
}
161
162
// Initialize context if needed
163
if (context.length == 0) {
164
context = new float[batchSize][contextSize];
165
}
166
167
// Concatenate context and input
168
float[][] xWithContext = new float[batchSize][contextSize + numSamples];
169
for (int i = 0; i < batchSize; i++) {
170
// Copy context
171
System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
172
// Copy input
173
System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
174
}
175
176
OrtEnvironment env = OrtEnvironment.getEnvironment();
177
178
OnnxTensor inputTensor = null;
179
OnnxTensor stateTensor = null;
180
OnnxTensor srTensor = null;
181
OrtSession.Result ortOutputs = null;
182
183
try {
184
// Create input tensors
185
inputTensor = OnnxTensor.createTensor(env, xWithContext);
186
stateTensor = OnnxTensor.createTensor(env, state);
187
srTensor = OnnxTensor.createTensor(env, new long[]{sr});
188
189
Map<String, OnnxTensor> inputs = new HashMap<>();
190
inputs.put("input", inputTensor);
191
inputs.put("sr", srTensor);
192
inputs.put("state", stateTensor);
193
194
// Run ONNX model inference
195
ortOutputs = session.run(inputs);
196
// Get output results
197
float[][] output = (float[][]) ortOutputs.get(0).getValue();
198
state = (float[][][]) ortOutputs.get(1).getValue();
199
200
// Update context - save the last contextSize samples from input
201
for (int i = 0; i < batchSize; i++) {
202
System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
203
context[i], 0, contextSize);
204
}
205
206
lastSr = sr;
207
lastBatchSize = batchSize;
208
return output[0];
209
} finally {
210
if (inputTensor != null) {
211
inputTensor.close();
212
}
213
if (stateTensor != null) {
214
stateTensor.close();
215
}
216
if (srTensor != null) {
217
srTensor.close();
218
}
219
if (ortOutputs != null) {
220
ortOutputs.close();
221
}
222
}
223
}
224
}
225
226