Path: blob/master/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
package org.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Silero VAD ONNX Model Wrapper
 *
 * @author VvvvvGH
 */
public class SlieroVadOnnxModel {
    // ONNX runtime session
    private final OrtSession session;
    // Model state - dimensions: [2, batch_size, 128]
    private float[][][] state;
    // Context - stores the tail of the previous audio chunk
    private float[][] context;
    // Last sample rate
    private int lastSr = 0;
    // Last batch size
    private int lastBatchSize = 0;
    // Supported sample rates
    private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);

    // Constructor
    public SlieroVadOnnxModel(String modelPath) throws OrtException {
        // Get the ONNX runtime environment
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // Create ONNX session options
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
        // Set InterOp thread count to 1 (for parallel processing of different graph operations)
        opts.setInterOpNumThreads(1);
        // Set IntraOp thread count to 1 (for parallel processing within a single operation)
        opts.setIntraOpNumThreads(1);
        // Add the CPU execution provider (with memory arena enabled)
        opts.addCPU(true);
        // Create ONNX session with the environment, model path, and options
        session = env.createSession(modelPath, opts);
        // Reset states
        resetStates();
    }

    /**
     * Reset states with default batch size
     */
    void resetStates() {
        resetStates(1);
    }

    /**
     * Reset states with specific batch size
     *
     * @param batchSize Batch size for state initialization
     */
    void resetStates(int batchSize) {
        state = new float[2][batchSize][128];
        context = new float[0][]; // Empty context
        lastSr = 0;
        lastBatchSize = 0;
    }

    public void close() throws OrtException {
        session.close();
    }

    /**
     * Inner class for validation result
     */
    public static class ValidationResult {
        public final float[][] x;
        public final int sr;

        public ValidationResult(float[][] x, int sr) {
            this.x = x;
            this.sr = sr;
        }
    }

    /**
     * Validate input data
     *
     * @param x  Audio data array
     * @param sr Sample rate
     * @return Validated input data and sample rate
     */
    private ValidationResult validateInput(float[][] x, int sr) {
        // Ensure input is at least 2D
        if (x.length == 1) {
            x = new float[][]{x[0]};
        }
        // Check if input dimension is valid
        if (x.length > 2) {
            throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
        }

        // Downsample if sample rate is a multiple of 16000
        if (sr != 16000 && (sr % 16000 == 0)) {
            int step = sr / 16000;
            float[][] reducedX = new float[x.length][];

            for (int i = 0; i < x.length; i++) {
                float[] current = x[i];
                float[] newArr = new float[(current.length + step - 1) / step];

                for (int j = 0, index = 0; j < current.length; j += step, index++) {
                    newArr[index] = current[j];
                }

                reducedX[i] = newArr;
            }

            x = reducedX;
            sr = 16000;
        }

        // Validate sample rate
        if (!SAMPLE_RATES.contains(sr)) {
            throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
        }

        // Check if audio chunk is too short
        if (((float) sr) / x[0].length > 31.25) {
            throw new IllegalArgumentException("Input audio is too short");
        }

        return new ValidationResult(x, sr);
    }

    /**
     * Call the ONNX model for inference
     *
     * @param x  Audio data array
     * @param sr Sample rate
     * @return Speech probability output
     * @throws OrtException If ONNX runtime error occurs
     */
    public float[] call(float[][] x, int sr) throws OrtException {
        ValidationResult result = validateInput(x, sr);
        x = result.x;
        sr = result.sr;

        int batchSize = x.length;
        int numSamples = sr == 16000 ? 512 : 256;
        int contextSize = sr == 16000 ? 64 : 32;

        // Reset states only when sample rate or batch size changes
        if (lastSr != 0 && lastSr != sr) {
            resetStates(batchSize);
        } else if (lastBatchSize != 0 && lastBatchSize != batchSize) {
            resetStates(batchSize);
        } else if (lastBatchSize == 0) {
            // First call - state is already initialized, just set batch size
            lastBatchSize = batchSize;
        }

        // Initialize context if needed
        if (context.length == 0) {
            context = new float[batchSize][contextSize];
        }

        // Concatenate context and input
        float[][] xWithContext = new float[batchSize][contextSize + numSamples];
        for (int i = 0; i < batchSize; i++) {
            // Copy context
            System.arraycopy(context[i], 0, xWithContext[i], 0, contextSize);
            // Copy input
            System.arraycopy(x[i], 0, xWithContext[i], contextSize, numSamples);
        }

        OrtEnvironment env = OrtEnvironment.getEnvironment();

        OnnxTensor inputTensor = null;
        OnnxTensor stateTensor = null;
        OnnxTensor srTensor = null;
        OrtSession.Result ortOutputs = null;

        try {
            // Create input tensors
            inputTensor = OnnxTensor.createTensor(env, xWithContext);
            stateTensor = OnnxTensor.createTensor(env, state);
            srTensor = OnnxTensor.createTensor(env, new long[]{sr});

            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input", inputTensor);
            inputs.put("sr", srTensor);
            inputs.put("state", stateTensor);

            // Run ONNX model inference
            ortOutputs = session.run(inputs);
            // Get output results
            float[][] output = (float[][]) ortOutputs.get(0).getValue();
            state = (float[][][]) ortOutputs.get(1).getValue();

            // Update context - save the last contextSize samples from input
            for (int i = 0; i < batchSize; i++) {
                System.arraycopy(xWithContext[i], xWithContext[i].length - contextSize,
                        context[i], 0, contextSize);
            }

            lastSr = sr;
            lastBatchSize = batchSize;
            return output[0];
        } finally {
            if (inputTensor != null) {
                inputTensor.close();
            }
            if (stateTensor != null) {
                stateTensor.close();
            }
            if (srTensor != null) {
                srTensor.close();
            }
            if (ortOutputs != null) {
                ortOutputs.close();
            }
        }
    }
}
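For reference, a minimal usage sketch in a separate file (not part of the listing above). It assumes a Silero VAD ONNX model file at a hypothetical local path ("silero_vad.onnx") and a single 512-sample chunk of 16 kHz mono audio already converted to floats in [-1, 1]; both the path and the chunking are illustrative.

// VadDemo.java - hypothetical usage sketch for SlieroVadOnnxModel
package org.example;

public class VadDemo {
    public static void main(String[] args) throws Exception {
        // Assumed local path to the Silero VAD ONNX model - adjust to your setup
        SlieroVadOnnxModel model = new SlieroVadOnnxModel("silero_vad.onnx");
        try {
            // One 512-sample chunk of 16 kHz mono audio, normalized to [-1, 1]
            float[][] chunk = new float[1][512];
            // ... fill chunk[0] from your audio source ...

            // call() returns one speech probability per batch entry
            float speechProb = model.call(chunk, 16000)[0];
            System.out.printf("speech probability: %.3f%n", speechProb);
        } finally {
            model.close();
        }
    }
}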