Path: blob/master/examples/java-example/src/main/java/org/example/App.java
package org.example;

import ai.onnxruntime.OrtException;
import javax.sound.sampled.*;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Silero VAD Java Example
 * Voice Activity Detection using ONNX model
 *
 * @author VvvvvGH
 */
public class App {

    // ONNX model path - using the model file from the project
    private static final String MODEL_PATH = "../../src/silero_vad/data/silero_vad.onnx";
    // Test audio file path
    private static final String AUDIO_FILE_PATH = "../../en_example.wav";
    // Sampling rate
    private static final int SAMPLE_RATE = 16000;
    // Speech threshold (consistent with Python default)
    private static final float THRESHOLD = 0.5f;
    // Negative threshold (used to determine speech end)
    private static final float NEG_THRESHOLD = 0.35f; // threshold - 0.15
    // Minimum speech duration (milliseconds)
    private static final int MIN_SPEECH_DURATION_MS = 250;
    // Minimum silence duration (milliseconds)
    private static final int MIN_SILENCE_DURATION_MS = 100;
    // Speech padding (milliseconds)
    private static final int SPEECH_PAD_MS = 30;
    // Window size (samples) - 512 samples for 16kHz
    private static final int WINDOW_SIZE_SAMPLES = 512;

    public static void main(String[] args) {
        System.out.println("=".repeat(60));
        System.out.println("Silero VAD Java ONNX Example");
        System.out.println("=".repeat(60));

        // Load ONNX model
        SlieroVadOnnxModel model;
        try {
            System.out.println("Loading ONNX model: " + MODEL_PATH);
            model = new SlieroVadOnnxModel(MODEL_PATH);
            System.out.println("Model loaded successfully!");
        } catch (OrtException e) {
            System.err.println("Failed to load model: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Read WAV file
        float[] audioData;
        try {
            System.out.println("\nReading audio file: " + AUDIO_FILE_PATH);
            audioData = readWavFileAsFloatArray(AUDIO_FILE_PATH);
            System.out.println("Audio file read successfully, samples: " + audioData.length);
            System.out.println("Audio duration: " + String.format("%.2f", (audioData.length / (float) SAMPLE_RATE)) + " seconds");
        } catch (Exception e) {
            System.err.println("Failed to read audio file: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Get speech timestamps (batch mode, consistent with Python's get_speech_timestamps)
        System.out.println("\nDetecting speech segments...");
        List<Map<String, Integer>> speechTimestamps;
        try {
            speechTimestamps = getSpeechTimestamps(
                    audioData,
                    model,
                    THRESHOLD,
                    SAMPLE_RATE,
                    MIN_SPEECH_DURATION_MS,
                    MIN_SILENCE_DURATION_MS,
                    SPEECH_PAD_MS,
                    NEG_THRESHOLD
            );
        } catch (OrtException e) {
            System.err.println("Failed to detect speech timestamps: " + e.getMessage());
            e.printStackTrace();
            return;
        }

        // Output detection results
        System.out.println("\nDetected speech timestamps (in samples):");
        for (Map<String, Integer> timestamp : speechTimestamps) {
            System.out.println(timestamp);
        }

        // Output summary
        System.out.println("\n" + "=".repeat(60));
        System.out.println("Detection completed!");
        System.out.println("Total detected " + speechTimestamps.size() + " speech segments");
        System.out.println("=".repeat(60));

        // Close model
        try {
            model.close();
        } catch (OrtException e) {
            System.err.println("Error closing model: " + e.getMessage());
        }
    }
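    /*
     * Note: SlieroVadOnnxModel is defined elsewhere in this example project and is not
     * shown in this file. Judging only from how it is used here, it is assumed to be a
     * thin ONNX Runtime wrapper roughly shaped like the sketch below; treat this as an
     * inferred interface (signatures approximate), not the actual class:
     *
     *   class SlieroVadOnnxModel {
     *       SlieroVadOnnxModel(String modelPath) throws OrtException { ... }  // load the ONNX model
     *       void resetStates() { ... }                                        // clear recurrent state
     *       float[] call(float[][] batch, int sr) throws OrtException { ... } // one speech probability per batch row
     *       void close() throws OrtException { ... }                          // release the ONNX session
     *   }
     */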
    /**
     * Get speech timestamps
     * Implements the same logic as Python's get_speech_timestamps
     *
     * @param audio                Audio data (float array)
     * @param model                ONNX model
     * @param threshold            Speech threshold
     * @param samplingRate         Sampling rate
     * @param minSpeechDurationMs  Minimum speech duration (milliseconds)
     * @param minSilenceDurationMs Minimum silence duration (milliseconds)
     * @param speechPadMs          Speech padding (milliseconds)
     * @param negThreshold         Negative threshold (used to determine speech end)
     * @return List of speech timestamps
     */
    private static List<Map<String, Integer>> getSpeechTimestamps(
            float[] audio,
            SlieroVadOnnxModel model,
            float threshold,
            int samplingRate,
            int minSpeechDurationMs,
            int minSilenceDurationMs,
            int speechPadMs,
            float negThreshold) throws OrtException {

        // Reset model states
        model.resetStates();

        // Calculate parameters
        int minSpeechSamples = samplingRate * minSpeechDurationMs / 1000;
        int speechPadSamples = samplingRate * speechPadMs / 1000;
        int minSilenceSamples = samplingRate * minSilenceDurationMs / 1000;
        int windowSizeSamples = samplingRate == 16000 ? 512 : 256;
        int audioLengthSamples = audio.length;

        // Calculate speech probabilities for all audio chunks
        List<Float> speechProbs = new ArrayList<>();
        for (int currentStart = 0; currentStart < audioLengthSamples; currentStart += windowSizeSamples) {
            float[] chunk = new float[windowSizeSamples];
            int chunkLength = Math.min(windowSizeSamples, audioLengthSamples - currentStart);
            System.arraycopy(audio, currentStart, chunk, 0, chunkLength);

            // Pad with zeros if chunk is shorter than window size
            if (chunkLength < windowSizeSamples) {
                for (int i = chunkLength; i < windowSizeSamples; i++) {
                    chunk[i] = 0.0f;
                }
            }

            float speechProb = model.call(new float[][]{chunk}, samplingRate)[0];
            speechProbs.add(speechProb);
        }
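
        // Segmentation below uses a simple hysteresis on the per-window probabilities:
        // a segment opens once the probability reaches `threshold`, a tentative end
        // (`tempEnd`) is recorded when it drops below `negThreshold`, and the segment
        // only closes after the probability has stayed low for at least
        // minSilenceSamples; shorter dips are discarded.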
        // Detect speech segments using the same algorithm as Python
        boolean triggered = false;
        List<Map<String, Integer>> speeches = new ArrayList<>();
        Map<String, Integer> currentSpeech = null;
        int tempEnd = 0;

        for (int i = 0; i < speechProbs.size(); i++) {
            float speechProb = speechProbs.get(i);

            // Reset temporary end if speech probability exceeds threshold
            if (speechProb >= threshold && tempEnd != 0) {
                tempEnd = 0;
            }

            // Detect speech start
            if (speechProb >= threshold && !triggered) {
                triggered = true;
                currentSpeech = new HashMap<>();
                currentSpeech.put("start", windowSizeSamples * i);
                continue;
            }

            // Detect speech end
            if (speechProb < negThreshold && triggered) {
                if (tempEnd == 0) {
                    tempEnd = windowSizeSamples * i;
                }
                if (windowSizeSamples * i - tempEnd < minSilenceSamples) {
                    continue;
                } else {
                    currentSpeech.put("end", tempEnd);
                    if (currentSpeech.get("end") - currentSpeech.get("start") > minSpeechSamples) {
                        speeches.add(currentSpeech);
                    }
                    currentSpeech = null;
                    tempEnd = 0;
                    triggered = false;
                }
            }
        }

        // Handle the last speech segment
        if (currentSpeech != null &&
                (audioLengthSamples - currentSpeech.get("start")) > minSpeechSamples) {
            currentSpeech.put("end", audioLengthSamples);
            speeches.add(currentSpeech);
        }

        // Add speech padding - same logic as Python
        for (int i = 0; i < speeches.size(); i++) {
            Map<String, Integer> speech = speeches.get(i);
            if (i == 0) {
                speech.put("start", Math.max(0, speech.get("start") - speechPadSamples));
            }
            if (i != speeches.size() - 1) {
                int silenceDuration = speeches.get(i + 1).get("start") - speech.get("end");
                if (silenceDuration < 2 * speechPadSamples) {
                    speech.put("end", speech.get("end") + silenceDuration / 2);
                    speeches.get(i + 1).put("start",
                            Math.max(0, speeches.get(i + 1).get("start") - silenceDuration / 2));
                } else {
                    speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
                    speeches.get(i + 1).put("start",
                            Math.max(0, speeches.get(i + 1).get("start") - speechPadSamples));
                }
            } else {
                speech.put("end", Math.min(audioLengthSamples, speech.get("end") + speechPadSamples));
            }
        }

        return speeches;
    }

    /**
     * Read WAV file and return as float array
     *
     * @param filePath WAV file path
     * @return Audio data as float array (normalized to -1.0 to 1.0)
     */
    private static float[] readWavFileAsFloatArray(String filePath)
            throws UnsupportedAudioFileException, IOException {
        File audioFile = new File(filePath);
        AudioInputStream audioStream = AudioSystem.getAudioInputStream(audioFile);

        // Get audio format information
        AudioFormat format = audioStream.getFormat();
        System.out.println("Audio format: " + format);

        // Read all audio data
        byte[] audioBytes = audioStream.readAllBytes();
        audioStream.close();

        // Convert to float array (assumes 16-bit, little-endian, mono PCM)
        float[] audioData = new float[audioBytes.length / 2];
        for (int i = 0; i < audioData.length; i++) {
            // 16-bit PCM: two bytes per sample (little-endian)
            short sample = (short) ((audioBytes[i * 2] & 0xff) | (audioBytes[i * 2 + 1] << 8));
            audioData[i] = sample / 32768.0f; // Normalize to -1.0 to 1.0
        }

        return audioData;
    }

}