Path: blob/master/examples/java-wav-file-example/src/main/java/org/example/SileroVadOnnxModel.java
1171 views
package org.example;12import ai.onnxruntime.OnnxTensor;3import ai.onnxruntime.OrtEnvironment;4import ai.onnxruntime.OrtException;5import ai.onnxruntime.OrtSession;6import java.util.Arrays;7import java.util.HashMap;8import java.util.List;9import java.util.Map;1011public class SileroVadOnnxModel {12// Define private variable OrtSession13private final OrtSession session;14private float[][][] state;15private float[][] context;16// Define the last sample rate17private int lastSr = 0;18// Define the last batch size19private int lastBatchSize = 0;20// Define a list of supported sample rates21private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);2223// Constructor24public SileroVadOnnxModel(String modelPath) throws OrtException {25// Get the ONNX runtime environment26OrtEnvironment env = OrtEnvironment.getEnvironment();27// Create an ONNX session options object28OrtSession.SessionOptions opts = new OrtSession.SessionOptions();29// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations30opts.setInterOpNumThreads(1);31// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation32opts.setIntraOpNumThreads(1);33// Add a CPU device, setting to false disables CPU execution optimization34opts.addCPU(true);35// Create an ONNX session using the environment, model path, and options36session = env.createSession(modelPath, opts);37// Reset states38resetStates();39}4041/**42* Reset states43*/44void resetStates() {45state = new float[2][1][128];46context = new float[0][];47lastSr = 0;48lastBatchSize = 0;49}5051public void close() throws OrtException {52session.close();53}5455/**56* Define inner class ValidationResult57*/58public static class ValidationResult {59public final float[][] x;60public final int sr;6162// Constructor63public ValidationResult(float[][] x, int sr) {64this.x = x;65this.sr = sr;66}67}6869/**70* Function to validate input data71*/72private ValidationResult validateInput(float[][] x, int sr) {73// Process the input data with dimension 174if (x.length == 1) {75x = new float[][]{x[0]};76}77// Throw an exception when the input data dimension is greater than 278if (x.length > 2) {79throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);80}8182// Process the input data when the sample rate is not equal to 16000 and is a multiple of 1600083if (sr != 16000 && (sr % 16000 == 0)) {84int step = sr / 16000;85float[][] reducedX = new float[x.length][];8687for (int i = 0; i < x.length; i++) {88float[] current = x[i];89float[] newArr = new float[(current.length + step - 1) / step];9091for (int j = 0, index = 0; j < current.length; j += step, index++) {92newArr[index] = current[j];93}9495reducedX[i] = newArr;96}9798x = reducedX;99sr = 16000;100}101102// If the sample rate is not in the list of supported sample rates, throw an exception103if (!SAMPLE_RATES.contains(sr)) {104throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");105}106107// If the input audio block is too short, throw an exception108if (((float) sr) / x[0].length > 31.25) {109throw new IllegalArgumentException("Input audio is too short");110}111112// Return the validated result113return new ValidationResult(x, sr);114}115116private static float[][] concatenate(float[][] a, float[][] b) {117if (a.length != b.length) {118throw new IllegalArgumentException("The number of rows in both arrays must be the same.");119}120121int rows = a.length;122int colsA = a[0].length;123int colsB = b[0].length;124float[][] result = new float[rows][colsA + colsB];125126for (int i = 0; i < rows; i++) {127System.arraycopy(a[i], 0, result[i], 0, colsA);128System.arraycopy(b[i], 0, result[i], colsA, colsB);129}130131return result;132}133134private static float[][] getLastColumns(float[][] array, int contextSize) {135int rows = array.length;136int cols = array[0].length;137138if (contextSize > cols) {139throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");140}141142float[][] result = new float[rows][contextSize];143144for (int i = 0; i < rows; i++) {145System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);146}147148return result;149}150151/**152* Method to call the ONNX model153*/154public float[] call(float[][] x, int sr) throws OrtException {155ValidationResult result = validateInput(x, sr);156x = result.x;157sr = result.sr;158int numberSamples = 256;159if (sr == 16000) {160numberSamples = 512;161}162163if (x[0].length != numberSamples) {164throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");165}166167int batchSize = x.length;168169int contextSize = 32;170if (sr == 16000) {171contextSize = 64;172}173174if (lastBatchSize == 0) {175resetStates();176}177if (lastSr != 0 && lastSr != sr) {178resetStates();179}180if (lastBatchSize != 0 && lastBatchSize != batchSize) {181resetStates();182}183184if (context.length == 0) {185context = new float[batchSize][contextSize];186}187188x = concatenate(context, x);189190OrtEnvironment env = OrtEnvironment.getEnvironment();191192OnnxTensor inputTensor = null;193OnnxTensor stateTensor = null;194OnnxTensor srTensor = null;195OrtSession.Result ortOutputs = null;196197try {198// Create input tensors199inputTensor = OnnxTensor.createTensor(env, x);200stateTensor = OnnxTensor.createTensor(env, state);201srTensor = OnnxTensor.createTensor(env, new long[]{sr});202203Map<String, OnnxTensor> inputs = new HashMap<>();204inputs.put("input", inputTensor);205inputs.put("sr", srTensor);206inputs.put("state", stateTensor);207208// Call the ONNX model for calculation209ortOutputs = session.run(inputs);210// Get the output results211float[][] output = (float[][]) ortOutputs.get(0).getValue();212state = (float[][][]) ortOutputs.get(1).getValue();213214context = getLastColumns(x, contextSize);215lastSr = sr;216lastBatchSize = batchSize;217return output[0];218} finally {219if (inputTensor != null) {220inputTensor.close();221}222if (stateTensor != null) {223stateTensor.close();224}225if (srTensor != null) {226srTensor.close();227}228if (ortOutputs != null) {229ortOutputs.close();230}231}232}233}234235236