GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/cpp/silero-vad-onnx.cpp

#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif

#include <iostream>
#include <vector>
#include <sstream>
#include <cstring>
#include <limits>
#include <chrono>
#include <iomanip>
#include <memory>
#include <string>
#include <stdexcept>
#include <cstdio>
#include <cstdarg>
#include <cmath>     // for std::rint
#include <algorithm> // for std::copy and std::fill

//#define __DEBUG_SPEECH_PROB___

#include "onnxruntime_cxx_api.h"
#include "wav.h" // For reading WAV files

// timestamp_t class: stores the start and end (in samples) of a speech segment.
class timestamp_t {
public:
    int start;
    int end;

    timestamp_t(int start = -1, int end = -1)
        : start(start), end(end) { }

    timestamp_t& operator=(const timestamp_t& a) {
        start = a.start;
        end = a.end;
        return *this;
    }

    bool operator==(const timestamp_t& a) const {
        return (start == a.start && end == a.end);
    }

    // Returns a formatted string of the timestamp.
    std::string c_str() const {
        return format("{start:%08d, end:%08d}", start, end);
    }

private:
    // Helper function for printf-style formatting into a std::string.
    std::string format(const char* fmt, ...) const {
        char buf[256];
        va_list args;
        va_start(args, fmt);
        const auto r = std::vsnprintf(buf, sizeof(buf), fmt, args);
        va_end(args);
        if (r < 0)
            return {};
        const size_t len = r;
        if (len < sizeof(buf))
            return std::string(buf, len);
#if __cplusplus >= 201703L
        // C++17: write directly into the string's buffer.
        std::string s(len, '\0');
        va_start(args, fmt);
        std::vsnprintf(s.data(), len + 1, fmt, args);
        va_end(args);
        return s;
#else
        // Pre-C++17: format into a temporary heap buffer.
        auto vbuf = std::unique_ptr<char[]>(new char[len + 1]);
        va_start(args, fmt);
        std::vsnprintf(vbuf.get(), len + 1, fmt, args);
        va_end(args);
        return std::string(vbuf.get(), len);
#endif
    }
};
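
// Example (illustrative values, added for clarity): timestamp_t(6144, 20992).c_str()
// returns "{start:00006144, end:00020992}".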

// VadIterator class: uses ONNX Runtime to detect speech segments.
class VadIterator {
private:
    // ONNX Runtime resources
    Ort::Env env;
    Ort::SessionOptions session_options;
    std::shared_ptr<Ort::Session> session = nullptr;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);

    // ----- Context-related additions -----
    const int context_samples = 64;  // For 16 kHz, 64 samples are added as context.
    std::vector<float> _context;     // Holds the last 64 samples of the previous chunk (initialized to zero).

    // Original window size (e.g., 32 ms corresponds to 512 samples at 16 kHz)
    int window_size_samples;
    // Effective window size = window_size_samples + context_samples
    int effective_window_size;
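
    // Each inference therefore sees a layout of
    //   [ context_samples | window_size_samples ] = effective_window_size
    // (e.g., 64 + 512 = 576 samples at 16 kHz), with the context carried over
    // from the previous chunk, mirroring the Python implementation.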

    // Samples per millisecond (e.g., 16 at 16 kHz)
    int sr_per_ms;

    // ONNX Runtime input/output buffers
    std::vector<Ort::Value> ort_inputs;
    std::vector<const char*> input_node_names = { "input", "state", "sr" };
    std::vector<float> input;
    unsigned int size_state = 2 * 1 * 128;
    std::vector<float> _state;
    std::vector<int64_t> sr;
    int64_t input_node_dims[2] = {};
    const int64_t state_node_dims[3] = { 2, 1, 128 };
    const int64_t sr_node_dims[1] = { 1 };
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char*> output_node_names = { "output", "stateN" };
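
    // For reference, the tensors built in predict() below are:
    //   input:  float32 [1, effective_window_size]
    //   state:  float32 [2, 1, 128]   (fed back from stateN after each call)
    //   sr:     int64   [1]
    // and the outputs read back are "output" (speech probability, first
    // element) and "stateN" (the updated recurrent state).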

    // Model configuration parameters
    int sample_rate;
    float threshold;
    int min_silence_samples;
    int min_silence_samples_at_max_speech;
    int min_speech_samples;
    float max_speech_samples;
    int speech_pad_samples;
    int audio_length_samples;

    // State management
    bool triggered = false;
    unsigned int temp_end = 0;
    unsigned int current_sample = 0;
    int prev_end;
    int next_start = 0;
    std::vector<timestamp_t> speeches;
    timestamp_t current_speech;

    // Loads the ONNX model.
    void init_onnx_model(const std::wstring& model_path) {
        init_engine_threads(1, 1);
        // Note: the wide-string path overload of Ort::Session is Windows-specific.
        session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
    }

    // Initializes threading settings.
    void init_engine_threads(int inter_threads, int intra_threads) {
        session_options.SetIntraOpNumThreads(intra_threads);
        session_options.SetInterOpNumThreads(inter_threads);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    }

    // Resets internal state (_state, _context, and segment tracking).
    void reset_states() {
        std::memset(_state.data(), 0, _state.size() * sizeof(float));
        triggered = false;
        temp_end = 0;
        current_sample = 0;
        prev_end = next_start = 0;
        speeches.clear();
        current_speech = timestamp_t();
        std::fill(_context.begin(), _context.end(), 0.0f);
    }

    // Runs inference on one chunk of input data.
    // data_chunk is expected to contain window_size_samples samples.
    void predict(const std::vector<float>& data_chunk) {
        // Build the model input: context_samples from _context, followed by
        // the current chunk (window_size_samples).
        std::vector<float> new_data(effective_window_size, 0.0f);
        std::copy(_context.begin(), _context.end(), new_data.begin());
        std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
        input = new_data;

        // Create the input tensors (input_node_dims[1] is already set to effective_window_size).
        Ort::Value input_ort = Ort::Value::CreateTensor<float>(
            memory_info, input.data(), input.size(), input_node_dims, 2);
        Ort::Value state_ort = Ort::Value::CreateTensor<float>(
            memory_info, _state.data(), _state.size(), state_node_dims, 3);
        Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
            memory_info, sr.data(), sr.size(), sr_node_dims, 1);
        ort_inputs.clear();
        ort_inputs.emplace_back(std::move(input_ort));
        ort_inputs.emplace_back(std::move(state_ort));
        ort_inputs.emplace_back(std::move(sr_ort));

        // Run inference.
        ort_outputs = session->Run(
            Ort::RunOptions{ nullptr },
            input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
            output_node_names.data(), output_node_names.size());

        float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
        float* stateN = ort_outputs[1].GetTensorMutableData<float>();
        std::memcpy(_state.data(), stateN, size_state * sizeof(float));
        current_sample += static_cast<unsigned int>(window_size_samples); // Advance by the original window size.

        // Case 1: speech detected (probability >= threshold).
        if (speech_prob >= threshold) {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample - window_size_samples;
            printf("{ start: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif
            if (temp_end != 0) {
                temp_end = 0;
                if (next_start < prev_end)
                    next_start = current_sample - window_size_samples;
            }
            if (!triggered) {
                triggered = true;
                current_speech.start = current_sample - window_size_samples;
            }
            // Update context: keep the last context_samples of new_data for the next chunk.
            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
            return;
        }

        // Case 2: the current speech segment has grown past max_speech_samples.
        if (triggered && ((current_sample - current_speech.start) > max_speech_samples)) {
            if (prev_end > 0) {
                // A silence candidate was recorded earlier: end the segment there
                // and possibly restart from next_start.
                current_speech.end = prev_end;
                speeches.push_back(current_speech);
                current_speech = timestamp_t();
                if (next_start < prev_end)
                    triggered = false;
                else
                    current_speech.start = next_start;
                prev_end = 0;
                next_start = 0;
                temp_end = 0;
            }
            else {
                // No silence candidate: cut the segment at the current sample.
                current_speech.end = current_sample;
                speeches.push_back(current_speech);
                current_speech = timestamp_t();
                prev_end = 0;
                next_start = 0;
                temp_end = 0;
                triggered = false;
            }
            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
            return;
        }

        // Case 3: probability in the hysteresis band [threshold - 0.15, threshold):
        // treat it as a temporary dip within speech and only update the context.
        if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) {
            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
            return;
        }

        // Case 4: probability below threshold - 0.15 (silence).
        if (speech_prob < (threshold - 0.15)) {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample - window_size_samples - speech_pad_samples;
            printf("{ end: %.3f s (%.3f) %08d}\n", 1.0f * speech / sample_rate, speech_prob, current_sample - window_size_samples);
#endif
            if (triggered) {
                if (temp_end == 0)
                    temp_end = current_sample;
                // Record a silence candidate in case the segment later hits max_speech_samples.
                if (current_sample - temp_end > min_silence_samples_at_max_speech)
                    prev_end = temp_end;
                // Close the segment once the silence is long enough and the speech was long enough.
                if ((current_sample - temp_end) >= min_silence_samples) {
                    current_speech.end = temp_end;
                    if (current_speech.end - current_speech.start > min_speech_samples) {
                        speeches.push_back(current_speech);
                        current_speech = timestamp_t();
                        prev_end = 0;
                        next_start = 0;
                        temp_end = 0;
                        triggered = false;
                    }
                }
            }
            std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
            return;
        }
    }

public:
    // Processes the entire audio input.
    void process(const std::vector<float>& input_wav) {
        reset_states();
        audio_length_samples = static_cast<int>(input_wav.size());
        // Process the audio in chunks of window_size_samples (e.g., 512 samples);
        // a trailing partial chunk is dropped.
        for (size_t j = 0; j < static_cast<size_t>(audio_length_samples); j += static_cast<size_t>(window_size_samples)) {
            if (j + static_cast<size_t>(window_size_samples) > static_cast<size_t>(audio_length_samples))
                break;
            std::vector<float> chunk(&input_wav[j], &input_wav[j] + window_size_samples);
            predict(chunk);
        }
        // Close any segment still open when the audio ends.
        if (current_speech.start >= 0) {
            current_speech.end = audio_length_samples;
            speeches.push_back(current_speech);
            current_speech = timestamp_t();
            prev_end = 0;
            next_start = 0;
            temp_end = 0;
            triggered = false;
        }
    }

    // Returns the detected speech timestamps (in samples).
    const std::vector<timestamp_t>& get_speech_timestamps() const {
        return speeches;
    }

    // Public method to reset the internal state.
    void reset() {
        reset_states();
    }

public:
    // Constructor: sets model path, sample rate, window size (ms), and other parameters.
    // The parameters are chosen to match the Python version.
    VadIterator(const std::wstring& ModelPath,
        int Sample_rate = 16000, int windows_frame_size = 32,
        float Threshold = 0.5, int min_silence_duration_ms = 100,
        int speech_pad_ms = 30, int min_speech_duration_ms = 250,
        float max_speech_duration_s = std::numeric_limits<float>::infinity())
        : sample_rate(Sample_rate), threshold(Threshold), prev_end(0)
    {
        sr_per_ms = sample_rate / 1000;                                // e.g., 16000 / 1000 = 16
        window_size_samples = windows_frame_size * sr_per_ms;          // e.g., 32 ms * 16 = 512 samples
        effective_window_size = window_size_samples + context_samples; // e.g., 512 + 64 = 576 samples
        input_node_dims[0] = 1;
        input_node_dims[1] = effective_window_size;
        _state.resize(size_state);
        sr.resize(1);
        sr[0] = sample_rate;
        _context.assign(context_samples, 0.0f);
        speech_pad_samples = sr_per_ms * speech_pad_ms; // Convert ms to samples.
        min_speech_samples = sr_per_ms * min_speech_duration_ms;
        max_speech_samples = (sample_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples);
        min_silence_samples = sr_per_ms * min_silence_duration_ms;
        min_silence_samples_at_max_speech = sr_per_ms * 98; // 98 ms
        init_onnx_model(ModelPath);
    }
};
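
// Usage sketch (added for illustration; the parameter values below are
// examples, not upstream recommendations):
//
//   VadIterator vad(L"model/silero_vad.onnx",
//                   /*Sample_rate=*/16000, /*windows_frame_size=*/32,
//                   /*Threshold=*/0.6f, /*min_silence_duration_ms=*/200);
//   vad.process(samples);
//   for (const timestamp_t& ts : vad.get_speech_timestamps())
//       std::cout << ts.c_str() << std::endl;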

int main() {
    // Read the WAV file (expects 16000 Hz, mono, PCM).
    wav::WavReader wav_reader("audio/recorder.wav"); // File located in the "audio" folder.
    int numSamples = wav_reader.num_samples();
    std::vector<float> input_wav(static_cast<size_t>(numSamples));
    for (size_t i = 0; i < static_cast<size_t>(numSamples); i++) {
        input_wav[i] = static_cast<float>(*(wav_reader.data() + i));
    }

    // Set the ONNX model path (file located in the "model" folder).
    std::wstring model_path = L"model/silero_vad.onnx";

    // Initialize the VadIterator.
    VadIterator vad(model_path);

    // Process the audio.
    vad.process(input_wav);

    // Retrieve the speech timestamps (in samples).
    std::vector<timestamp_t> stamps = vad.get_speech_timestamps();

    // Convert the timestamps to seconds, rounded to one decimal place (16000 Hz).
    const float sample_rate_float = 16000.0f;
    for (size_t i = 0; i < stamps.size(); i++) {
        float start_sec = std::rint((stamps[i].start / sample_rate_float) * 10.0f) / 10.0f;
        float end_sec = std::rint((stamps[i].end / sample_rate_float) * 10.0f) / 10.0f;
        std::cout << "Speech detected from "
                  << std::fixed << std::setprecision(1) << start_sec
                  << " s to "
                  << std::fixed << std::setprecision(1) << end_sec
                  << " s" << std::endl;
    }

    // Optionally, reset the internal state.
    vad.reset();

    return 0;
}
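
// Build note (added; paths are placeholders, adjust to your setup). The
// wide-string model path targets the Windows build of ONNX Runtime, so a
// plausible MSVC build line, assuming the ONNX Runtime SDK is unpacked at
// <ort> and wav.h is on the include path, would be:
//
//   cl /std:c++17 /EHsc silero-vad-onnx.cpp /I <ort>\include /link /LIBPATH:<ort>\lib onnxruntime.lib
//
// The program expects audio/recorder.wav and model/silero_vad.onnx relative
// to the working directory.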