GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/java-example/src/main/java/org/example/App.java
package org.example;

import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Silero VAD Java Example
 * Voice Activity Detection using ONNX model
 *
 * @author VvvvvGH
 */
public class App {

    // ONNX model path - using the model file from the project
    private static final String MODEL_PATH = "../../src/silero_vad/data/silero_vad.onnx";
    // Test audio file path
    private static final String AUDIO_FILE_PATH = "../../en_example.wav";
    // Sampling rate
    private static final int SAMPLE_RATE = 16000;
    // Speech threshold (consistent with the Python default)
    private static final float THRESHOLD = 0.5f;
    // Negative threshold (used to determine speech end)
    private static final float NEG_THRESHOLD = 0.35f; // threshold - 0.15
    // Minimum speech duration (milliseconds)
    private static final int MIN_SPEECH_DURATION_MS = 250;
    // Minimum silence duration (milliseconds)
    private static final int MIN_SILENCE_DURATION_MS = 100;
    // Speech padding (milliseconds)
    private static final int SPEECH_PAD_MS = 30;
    // Window size (samples) - 512 samples for 16 kHz
    private static final int WINDOW_SIZE_SAMPLES = 512;
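    // --- Illustrative addition (not part of the original example) ---
    // The millisecond parameters above are converted to sample counts with
    // samples = sampleRate * ms / 1000, exactly as done inside
    // getSpeechTimestamps below; e.g. MIN_SPEECH_DURATION_MS = 250 at 16 kHz
    // corresponds to 4000 samples. The helper name millisToSamples is
    // hypothetical and only illustrates that relation.
    private static int millisToSamples(int sampleRate, int millis) {
        return sampleRate * millis / 1000;
    }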
    public static void main(String[] args) {
        System.out.println("=".repeat(60));
        System.out.println("Silero VAD Java ONNX Example");
        System.out.println("=".repeat(60));

        // Load ONNX model
        SlieroVadOnnxModel model;
        try {
            System.out.println("Loading ONNX model: " + MODEL_PATH);
            model = new SlieroVadOnnxModel(MODEL_PATH);
            System.out.println("Model loaded successfully!");
        } catch (OrtException e) {
            System.err.println("Failed to load model: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Read WAV file
        float[] audioData;
        try {
            System.out.println("\nReading audio file: " + AUDIO_FILE_PATH);
            audioData = readWavFileAsFloatArray(AUDIO_FILE_PATH);
            System.out.println("Audio file read successfully, samples: " + audioData.length);
            System.out.println("Audio duration: " + String.format("%.2f", (audioData.length / (float) SAMPLE_RATE)) + " seconds");
        } catch (Exception e) {
            System.err.println("Failed to read audio file: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Get speech timestamps (batch mode, consistent with Python's get_speech_timestamps)
        System.out.println("\nDetecting speech segments...");
        List<Map<String, Integer>> speechTimestamps;
        try {
            speechTimestamps = getSpeechTimestamps(
                    audioData,
                    model,
                    THRESHOLD,
                    SAMPLE_RATE,
                    MIN_SPEECH_DURATION_MS,
                    MIN_SILENCE_DURATION_MS,
                    SPEECH_PAD_MS,
                    NEG_THRESHOLD
            );
        } catch (OrtException e) {
            System.err.println("Failed to detect speech timestamps: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Output detection results
        System.out.println("\nDetected speech timestamps (in samples):");
        for (Map<String, Integer> timestamp : speechTimestamps) {
            System.out.println(timestamp);
        }

        // Output summary
        System.out.println("\n" + "=".repeat(60));
        System.out.println("Detection completed!");
        System.out.println("Detected " + speechTimestamps.size() + " speech segments in total");
        System.out.println("=".repeat(60));

        // Close model
        try {
            model.close();
        } catch (OrtException e) {
            System.err.println("Error closing model: " + e.getMessage());
        }
    }
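    // --- Illustrative addition (not part of the original example) ---
    // The timestamps printed above are expressed in samples. A small helper
    // like this sketch could render them in seconds for display instead,
    // dividing by the sampling rate. The method name formatTimestampInSeconds
    // is hypothetical; it assumes the "start"/"end" keys used in this class.
    private static String formatTimestampInSeconds(Map<String, Integer> timestamp, int sampleRate) {
        // Convert sample offsets to seconds by dividing by the sampling rate
        float startSec = timestamp.get("start") / (float) sampleRate;
        float endSec = timestamp.get("end") / (float) sampleRate;
        return String.format("{start: %.2fs, end: %.2fs}", startSec, endSec);
    }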
    /**
     * Get speech timestamps
     * Implements the same logic as Python's get_speech_timestamps
     *
     * @param audio                Audio data (float array)
     * @param model                ONNX model
     * @param threshold            Speech threshold
     * @param samplingRate         Sampling rate
     * @param minSpeechDurationMs  Minimum speech duration (milliseconds)
     * @param minSilenceDurationMs Minimum silence duration (milliseconds)
     * @param speechPadMs          Speech padding (milliseconds)
     * @param negThreshold         Negative threshold (used to determine speech end)
     * @return List of speech timestamps
     */
    private static List<Map<String, Integer>> getSpeechTimestamps(
            float[] audio,
            SlieroVadOnnxModel model,
            float threshold,
            int samplingRate,
            int minSpeechDurationMs,
            int minSilenceDurationMs,
            int speechPadMs,
            float negThreshold) throws OrtException {

        // Reset model states
        model.resetStates();

        // Calculate parameters
        int minSpeechSamples = samplingRate * minSpeechDurationMs / 1000;
        int speechPadSamples = samplingRate * speechPadMs / 1000;
        int minSilenceSamples = samplingRate * minSilenceDurationMs / 1000;
        int windowSizeSamples = samplingRate == 16000 ? 512 : 256;
        int audioLengthSamples = audio.length;

        // Calculate speech probabilities for all audio chunks
        List<Float> speechProbs = new ArrayList<>();
        for (int currentStart = 0; currentStart < audioLengthSamples; currentStart += windowSizeSamples) {
            float[] chunk = new float[windowSizeSamples];
            int chunkLength = Math.min(windowSizeSamples, audioLengthSamples - currentStart);
            System.arraycopy(audio, currentStart, chunk, 0, chunkLength);

            // Pad with zeros if chunk is shorter than window size
            if (chunkLength < windowSizeSamples) {
                for (int i = chunkLength; i < windowSizeSamples; i++) {
                    chunk[i] = 0.0f;
                }
            }

            float speechProb = model.call(new float[][]{chunk}, samplingRate)[0];
            speechProbs.add(speechProb);
        }

        // Detect speech segments using the same algorithm as Python
        boolean triggered = false;
        List<Map<String, Integer>> speeches = new ArrayList<>();
        Map<String, Integer> currentSpeech = null;
        int tempEnd = 0;

        for (int i = 0; i < speechProbs.size(); i++) {
            float speechProb = speechProbs.get(i);

            // Reset temporary end if speech probability exceeds threshold
            if (speechProb >= threshold && tempEnd != 0) {
                tempEnd = 0;
            }

            // Detect speech start
            if (speechProb >= threshold && !triggered) {
                triggered = true;
                currentSpeech = new HashMap<>();
                currentSpeech.put("start", windowSizeSamples * i);
                continue;
            }

            // Detect speech end
            if (speechProb < negThreshold && triggered) {
                if (tempEnd == 0) {
                    tempEnd = windowSizeSamples * i;
                }
                if (windowSizeSamples * i - tempEnd < minSilenceSamples) {
                    continue;
                } else {
                    currentSpeech.put("end", tempEnd);
                    if (currentSpeech.get("end") - currentSpeech.get("start") > minSpeechSamples) {
                        speeches.add(currentSpeech);
                    }
                    currentSpeech = null;
                    tempEnd = 0;
                    triggered = false;
                }
            }
        }

        // Handle the last speech segment
        if (currentSpeech != null &&
                (audioLengthSamples - currentSpeech.get("start")) > minSpeechSamples) {
            currentSpeech.put("end", audioLengthSamples);
            speeches.add(currentSpeech);
        }

        // Add speech padding - same logic as Python
        for (int i = 0; i < speeches.size(); i++) {
            Map<String, Integer> speech = speeches.get(i);
            if (i == 0) {
                speech.put("start", Math.max(0, speech.get("start") - speechPadSamples));
            }
            if (i != speeches.size() - 1) {
                int silenceDuration = speeches.get(i + 1).get("start") - speech.get("end");
                if (silenceDuration < 2 * speechPadSamples) {
                    speech.put("end", speech.get("end") + silenceDuration / 2);
                    speeches.get(i + 1).put("start",
                            Math.max(0, speeches.get(i + 1).get("start") - silenceDuration / 2));
                } else {
                    speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
                    speeches.get(i + 1).put("start",
                            Math.max(0, speeches.get(i + 1).get("start") - speechPadSamples));
                }
            } else {
                speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
            }
        }

        return speeches;
    }
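    // --- Illustrative addition (not part of the original example) ---
    // Sketch of how the detected timestamps could be used to cut the speech
    // segments out of the original audio buffer (similar in spirit to the
    // Python helper collect_chunks). The method name extractSpeechSegments is
    // hypothetical and assumes the "start"/"end" keys produced above.
    private static List<float[]> extractSpeechSegments(float[] audio, List<Map<String, Integer>> timestamps) {
        List<float[]> segments = new ArrayList<>();
        for (Map<String, Integer> ts : timestamps) {
            int start = ts.get("start");
            int end = Math.min(ts.get("end"), audio.length);
            // Copy the [start, end) sample range into its own array
            segments.add(java.util.Arrays.copyOfRange(audio, start, end));
        }
        return segments;
    }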
    /**
     * Read WAV file and return as float array
     *
     * @param filePath WAV file path
     * @return Audio data as float array (normalized to -1.0 to 1.0)
     */
    private static float[] readWavFileAsFloatArray(String filePath)
            throws UnsupportedAudioFileException, IOException {
        File audioFile = new File(filePath);
        AudioInputStream audioStream = AudioSystem.getAudioInputStream(audioFile);

        // Get audio format information
        AudioFormat format = audioStream.getFormat();
        System.out.println("Audio format: " + format);

        // Read all audio data
        byte[] audioBytes = audioStream.readAllBytes();
        audioStream.close();

        // Convert to float array
        float[] audioData = new float[audioBytes.length / 2];
        for (int i = 0; i < audioData.length; i++) {
            // 16-bit PCM: two bytes per sample (little-endian)
            short sample = (short) ((audioBytes[i * 2] & 0xff) | (audioBytes[i * 2 + 1] << 8));
            audioData[i] = sample / 32768.0f; // Normalize to -1.0 to 1.0
        }

        return audioData;
    }
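    // --- Illustrative addition (not part of the original example) ---
    // readWavFileAsFloatArray assumes 16-bit little-endian mono PCM at the
    // expected sampling rate. A defensive check along these lines could be
    // called on the AudioFormat before converting the bytes; the method name
    // checkAudioFormat is hypothetical.
    private static void checkAudioFormat(AudioFormat format, int expectedSampleRate) {
        if (format.getSampleSizeInBits() != 16
                || format.getChannels() != 1
                || format.isBigEndian()
                || Math.round(format.getSampleRate()) != expectedSampleRate) {
            throw new IllegalArgumentException(
                    "Expected 16-bit little-endian mono PCM at " + expectedSampleRate + " Hz, got: " + format);
        }
    }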
}