Path: blob/master/examples/csharp/SileroVadOnnxModel.cs
1171 views
using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Collections.Generic; using System.Linq; namespace VADdotnet; public class SileroVadOnnxModel : IDisposable { private readonly InferenceSession session; private float[][][] state; private float[][] context; private int lastSr = 0; private int lastBatchSize = 0; private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 }; public SileroVadOnnxModel(string modelPath) { var sessionOptions = new SessionOptions(); sessionOptions.InterOpNumThreads = 1; sessionOptions.IntraOpNumThreads = 1; sessionOptions.EnableCpuMemArena = true; session = new InferenceSession(modelPath, sessionOptions); ResetStates(); } public void ResetStates() { state = new float[2][][]; state[0] = new float[1][]; state[1] = new float[1][]; state[0][0] = new float[128]; state[1][0] = new float[128]; context = Array.Empty<float[]>(); lastSr = 0; lastBatchSize = 0; } public void Dispose() { session?.Dispose(); } public class ValidationResult { public float[][] X { get; } public int Sr { get; } public ValidationResult(float[][] x, int sr) { X = x; Sr = sr; } } private ValidationResult ValidateInput(float[][] x, int sr) { if (x.Length == 1) { x = new float[][] { x[0] }; } if (x.Length > 2) { throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}"); } if (sr != 16000 && (sr % 16000 == 0)) { int step = sr / 16000; float[][] reducedX = new float[x.Length][]; for (int i = 0; i < x.Length; i++) { float[] current = x[i]; float[] newArr = new float[(current.Length + step - 1) / step]; for (int j = 0, index = 0; j < current.Length; j += step, index++) { newArr[index] = current[j]; } reducedX[i] = newArr; } x = reducedX; sr = 16000; } if (!SAMPLE_RATES.Contains(sr)) { throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)"); } if (((float)sr) / x[0].Length > 31.25) { throw new ArgumentException("Input audio is too short"); } return new ValidationResult(x, sr); } private static float[][] Concatenate(float[][] a, float[][] b) { if (a.Length != b.Length) { throw new ArgumentException("The number of rows in both arrays must be the same."); } int rows = a.Length; int colsA = a[0].Length; int colsB = b[0].Length; float[][] result = new float[rows][]; for (int i = 0; i < rows; i++) { result[i] = new float[colsA + colsB]; Array.Copy(a[i], 0, result[i], 0, colsA); Array.Copy(b[i], 0, result[i], colsA, colsB); } return result; } private static float[][] GetLastColumns(float[][] array, int contextSize) { int rows = array.Length; int cols = array[0].Length; if (contextSize > cols) { throw new ArgumentException("contextSize cannot be greater than the number of columns in the array."); } float[][] result = new float[rows][]; for (int i = 0; i < rows; i++) { result[i] = new float[contextSize]; Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize); } return result; } public float[] Call(float[][] x, int sr) { var result = ValidateInput(x, sr); x = result.X; sr = result.Sr; int numberSamples = sr == 16000 ? 512 : 256; if (x[0].Length != numberSamples) { throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)"); } int batchSize = x.Length; int contextSize = sr == 16000 ? 64 : 32; if (lastBatchSize == 0) { ResetStates(); } if (lastSr != 0 && lastSr != sr) { ResetStates(); } if (lastBatchSize != 0 && lastBatchSize != batchSize) { ResetStates(); } if (context.Length == 0) { context = new float[batchSize][]; for (int i = 0; i < batchSize; i++) { context[i] = new float[contextSize]; } } x = Concatenate(context, x); var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })), NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })), NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length })) }; using (var outputs = session.Run(inputs)) { var output = outputs.First(o => o.Name == "output").AsTensor<float>(); var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>(); context = GetLastColumns(x, contextSize); lastSr = sr; lastBatchSize = batchSize; state = new float[newState.Dimensions[0]][][]; for (int i = 0; i < newState.Dimensions[0]; i++) { state[i] = new float[newState.Dimensions[1]][]; for (int j = 0; j < newState.Dimensions[1]; j++) { state[i][j] = new float[newState.Dimensions[2]]; for (int k = 0; k < newState.Dimensions[2]; k++) { state[i][j][k] = newState[i, j, k]; } } } return output.ToArray(); } } }