// Path: blob/master/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java
// 1171 views
package org.example;12import ai.onnxruntime.OnnxTensor;3import ai.onnxruntime.OrtEnvironment;4import ai.onnxruntime.OrtException;5import ai.onnxruntime.OrtSession;6import java.util.Arrays;7import java.util.HashMap;8import java.util.List;9import java.util.Map;1011public class SlieroVadOnnxModel {12// Define private variable OrtSession13private final OrtSession session;14private float[][][] h;15private float[][][] c;16// Define the last sample rate17private int lastSr = 0;18// Define the last batch size19private int lastBatchSize = 0;20// Define a list of supported sample rates21private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);2223// Constructor24public SlieroVadOnnxModel(String modelPath) throws OrtException {25// Get the ONNX runtime environment26OrtEnvironment env = OrtEnvironment.getEnvironment();27// Create an ONNX session options object28OrtSession.SessionOptions opts = new OrtSession.SessionOptions();29// Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations30opts.setInterOpNumThreads(1);31// Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation32opts.setIntraOpNumThreads(1);33// Add a CPU device, setting to false disables CPU execution optimization34opts.addCPU(true);35// Create an ONNX session using the environment, model path, and options36session = env.createSession(modelPath, opts);37// Reset states38resetStates();39}4041/**42* Reset states43*/44void resetStates() {45h = new float[2][1][64];46c = new float[2][1][64];47lastSr = 0;48lastBatchSize = 0;49}5051public void close() throws OrtException {52session.close();53}5455/**56* Define inner class ValidationResult57*/58public static class ValidationResult {59public final float[][] x;60public final int sr;6162// Constructor63public ValidationResult(float[][] x, int sr) {64this.x = x;65this.sr = sr;66}67}6869/**70* Function to validate input 
data71*/72private ValidationResult validateInput(float[][] x, int sr) {73// Process the input data with dimension 174if (x.length == 1) {75x = new float[][]{x[0]};76}77// Throw an exception when the input data dimension is greater than 278if (x.length > 2) {79throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);80}8182// Process the input data when the sample rate is not equal to 16000 and is a multiple of 1600083if (sr != 16000 && (sr % 16000 == 0)) {84int step = sr / 16000;85float[][] reducedX = new float[x.length][];8687for (int i = 0; i < x.length; i++) {88float[] current = x[i];89float[] newArr = new float[(current.length + step - 1) / step];9091for (int j = 0, index = 0; j < current.length; j += step, index++) {92newArr[index] = current[j];93}9495reducedX[i] = newArr;96}9798x = reducedX;99sr = 16000;100}101102// If the sample rate is not in the list of supported sample rates, throw an exception103if (!SAMPLE_RATES.contains(sr)) {104throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");105}106107// If the input audio block is too short, throw an exception108if (((float) sr) / x[0].length > 31.25) {109throw new IllegalArgumentException("Input audio is too short");110}111112// Return the validated result113return new ValidationResult(x, sr);114}115116/**117* Method to call the ONNX model118*/119public float[] call(float[][] x, int sr) throws OrtException {120ValidationResult result = validateInput(x, sr);121x = result.x;122sr = result.sr;123124int batchSize = x.length;125126if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {127resetStates();128}129130OrtEnvironment env = OrtEnvironment.getEnvironment();131132OnnxTensor inputTensor = null;133OnnxTensor hTensor = null;134OnnxTensor cTensor = null;135OnnxTensor srTensor = null;136OrtSession.Result ortOutputs = null;137138try {139// Create input tensors140inputTensor = OnnxTensor.createTensor(env, 
x);141hTensor = OnnxTensor.createTensor(env, h);142cTensor = OnnxTensor.createTensor(env, c);143srTensor = OnnxTensor.createTensor(env, new long[]{sr});144145Map<String, OnnxTensor> inputs = new HashMap<>();146inputs.put("input", inputTensor);147inputs.put("sr", srTensor);148inputs.put("h", hTensor);149inputs.put("c", cTensor);150151// Call the ONNX model for calculation152ortOutputs = session.run(inputs);153// Get the output results154float[][] output = (float[][]) ortOutputs.get(0).getValue();155h = (float[][][]) ortOutputs.get(1).getValue();156c = (float[][][]) ortOutputs.get(2).getValue();157158lastSr = sr;159lastBatchSize = batchSize;160return output[0];161} finally {162if (inputTensor != null) {163inputTensor.close();164}165if (hTensor != null) {166hTensor.close();167}168if (cTensor != null) {169cTensor.close();170}171if (srTensor != null) {172srTensor.close();173}174if (ortOutputs != null) {175ortOutputs.close();176}177}178}179}180181182