Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
snakers4
GitHub Repository: snakers4/silero-vad
Path: blob/master/examples/c++/silero.cc
1902 views
1
// silero.cc
2
// Author : NathanJHLee
3
// Created On : 2025-11-10
4
// Description : silero 6.2 system for onnx-runtime(c++) and torch-script(c++)
5
// Version : 1.3
6
7
#include "silero.h"
8
9
10
namespace silero {
11
12
#ifdef USE_TORCH
13
// Constructor (libtorch backend).
// Copies the VAD tuning parameters into members and loads the TorchScript
// model immediately; call SetVariables() afterwards to derive sample counts.
// NOTE(review): this initializer list starts threshold, sample_rate while the
// ONNX constructor starts sample_rate, threshold — actual initialization
// follows the declaration order in silero.h; verify both lists match it to
// avoid -Wreorder surprises.
VadIterator::VadIterator(const std::string &model_path,
                         float threshold,             // speech-probability trigger level
                         int sample_rate,             // input sample rate in Hz
                         int window_size_ms,          // analysis window length
                         int speech_pad_ms,           // padding applied around segments
                         int min_silence_duration_ms, // silence required to close a segment
                         int min_speech_duration_ms,  // shorter segments are discarded
                         int max_duration_merge_ms,   // stored; not referenced elsewhere in this file
                         bool print_as_samples)       // report boundaries in samples (true) or seconds
    : threshold(threshold), sample_rate(sample_rate), window_size_ms(window_size_ms),
      speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
      min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
      print_as_samples(print_as_samples)
{
    init_torch_model(model_path);
}
29
30
// Defaulted destructor: the TorchScript module and all other members clean
// themselves up via RAII; nothing manual to release.
VadIterator::~VadIterator() = default;
32
33
34
// Loads the TorchScript VAD model from model_path and switches it to eval
// mode. Propagates the loader's exception if the file is missing or invalid.
void VadIterator::init_torch_model(const std::string& model_path) {
    at::set_num_threads(1);  // keep inference single-threaded
    model = torch::jit::load(model_path);

    model.eval();
    // NOTE(review): this guard is destroyed when the function returns, so it
    // does not affect later forward() calls; the per-inference guard in
    // SpeechProbs() is the one that matters.
    torch::NoGradGuard no_grad;
    std::cout<<"Silero libtorch-Model loaded successfully"<<std::endl;
}
42
43
void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
44
int num_samples = input_wav.size();
45
int num_chunks = num_samples / window_size_samples;
46
int remainder_samples = num_samples % window_size_samples;
47
total_sample_size += num_samples;
48
49
std::vector<torch::Tensor> chunks;
50
51
for (int i = 0; i < num_chunks; i++) {
52
float* chunk_start = input_wav.data() + i * window_size_samples;
53
torch::Tensor chunk = torch::from_blob(chunk_start, {1, window_size_samples}, torch::kFloat32);
54
chunks.push_back(chunk);
55
56
if (i == num_chunks - 1 && remainder_samples > 0) {
57
int remaining_samples = num_samples - num_chunks * window_size_samples;
58
float* chunk_start_remainder = input_wav.data() + num_chunks * window_size_samples;
59
torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1, remaining_samples}, torch::kFloat32);
60
torch::Tensor padded_chunk = torch::cat({remainder_chunk, torch::zeros({1, window_size_samples - remaining_samples}, torch::kFloat32)}, 1);
61
chunks.push_back(padded_chunk);
62
}
63
}
64
65
if (!chunks.empty()) {
66
std::vector<torch::Tensor> outputs;
67
torch::Tensor batched_chunks = torch::stack(chunks);
68
for (size_t i = 0; i < chunks.size(); i++) {
69
torch::NoGradGuard no_grad;
70
std::vector<torch::jit::IValue> inputs;
71
inputs.push_back(batched_chunks[i]);
72
inputs.push_back(sample_rate);
73
torch::Tensor output = model.forward(inputs).toTensor();
74
outputs.push_back(output);
75
}
76
torch::Tensor all_outputs = torch::stack(outputs);
77
for (size_t i = 0; i < chunks.size(); i++) {
78
float output_f = all_outputs[i].item<float>();
79
outputs_prob.push_back(output_f);
80
//////To print Probs by libtorch
81
//std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
82
}
83
}
84
}
85
86
87
#elif USE_ONNX
88
89
// Constructor (ONNX Runtime backend).
// Copies the VAD tuning parameters, creates the ORT environment / session
// scaffolding, then loads the model. context_samples/_context start at 64
// samples here but are recomputed from the window size in init_engine().
// NOTE(review): the initializer list begins sample_rate, threshold — the
// libtorch constructor uses the opposite order; actual initialization follows
// member declaration order in silero.h, so verify both lists match it.
VadIterator::VadIterator(const std::string &model_path,
                         float threshold,             // speech-probability trigger level
                         int sample_rate,             // input sample rate in Hz
                         int window_size_ms,          // analysis window length
                         int speech_pad_ms,           // padding applied around segments
                         int min_silence_duration_ms, // silence required to close a segment
                         int min_speech_duration_ms,  // shorter segments are discarded
                         int max_duration_merge_ms,   // stored; not referenced elsewhere in this file
                         bool print_as_samples)       // report boundaries in samples (true) or seconds
    :sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms),
    speech_pad_ms(speech_pad_ms), min_silence_duration_ms(min_silence_duration_ms),
    min_speech_duration_ms(min_speech_duration_ms), max_duration_merge_ms(max_duration_merge_ms),
    print_as_samples(print_as_samples),
    env(ORT_LOGGING_LEVEL_ERROR, "Vad"), session_options(), session(nullptr), allocator(),
    memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU)), context_samples(64),
    // recurrent state tensor is {2, 1, 128} -> 256 floats
    _context(64, 0.0f), current_sample(0), size_state(2 * 1 * 128),
    input_node_names({"input", "state", "sr"}), output_node_names({"output", "stateN"}),
    state_node_dims{2, 1, 128}, sr_node_dims{1}

{
    init_onnx_model(model_path);
}
111
// Defaulted destructor: the shared Ort::Session and the other ORT handles are
// released by their own destructors; nothing manual to do here.
VadIterator::~VadIterator() = default;
113
114
void VadIterator::init_onnx_model(const std::string& model_path) {
115
int inter_threads=1;
116
int intra_threads=1;
117
session_options.SetIntraOpNumThreads(intra_threads);
118
session_options.SetInterOpNumThreads(inter_threads);
119
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
120
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
121
std::cout<<"Silero onnx-Model loaded successfully"<<std::endl;
122
}
123
124
// Runs one VAD inference step on a single window of audio.
// The tensor fed to the model is [ _context (context_samples) | data_chunk ],
// i.e. the tail of the previous window prepended to the current one.
// Side effects: overwrites _state with the model's updated recurrent state and
// _context with the last context_samples of this input (prefix for the next
// call). The statement order (Run -> state copy -> context copy) matters.
// Returns the model's speech probability for this window.
float VadIterator::predict(const std::vector<float>& data_chunk) {
    // Build the model input by prepending the saved context to the current chunk.
    std::vector<float> new_data(effective_window_size, 0.0f);
    std::copy(_context.begin(), _context.end(), new_data.begin());
    std::copy(data_chunk.begin(), data_chunk.end(), new_data.begin() + context_samples);
    input = new_data;  // `input` must outlive the Run() call — the tensor below wraps it

    // Wrap input, recurrent state, and sample rate as ORT tensors (no copies).
    Ort::Value input_ort = Ort::Value::CreateTensor<float>(
        memory_info, input.data(), input.size(), input_node_dims, 2);
    Ort::Value state_ort = Ort::Value::CreateTensor<float>(
        memory_info, _state.data(), _state.size(), state_node_dims, 3);
    Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
        memory_info, sr.data(), sr.size(), sr_node_dims, 1);
    ort_inputs.clear();
    ort_inputs.push_back(std::move(input_ort));
    ort_inputs.push_back(std::move(state_ort));
    ort_inputs.push_back(std::move(sr_ort));

    ort_outputs = session->Run(
        Ort::RunOptions{ nullptr },
        input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
        output_node_names.data(), output_node_names.size());

    float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0]; // ONNX output 0: first value is the speech probability

    float* stateN = ort_outputs[1].GetTensorMutableData<float>(); // output 1: updated recurrent state
    std::memcpy(_state.data(), stateN, size_state * sizeof(float));

    std::copy(new_data.end() - context_samples, new_data.end(), _context.begin());
    // Context update: keep the last context_samples of new_data for the next call.

    return speech_prob;
}
157
void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
158
reset_states();
159
total_sample_size = static_cast<int>(input_wav.size());
160
for (size_t j = 0; j < static_cast<size_t>(total_sample_size); j += window_size_samples) {
161
if (j + window_size_samples > static_cast<size_t>(total_sample_size))
162
break;
163
std::vector<float> chunk(input_wav.begin() + j, input_wav.begin() + j + window_size_samples);
164
float speech_prob = predict(chunk);
165
outputs_prob.push_back(speech_prob);
166
}
167
}
168
169
#endif
170
171
// Resets all per-stream VAD state: the trigger flag, sample counters, the
// accumulated probabilities, and the backend model's recurrent state.
// Called by the ONNX SpeechProbs() before each run and by DoVad() on exit
// (which is why outputs_prob is consumed after segment extraction).
void VadIterator::reset_states() {
    triggered = false;
    current_sample = 0;
    temp_end = 0;
    outputs_prob.clear();
    total_sample_size = 0;

#ifdef USE_TORCH
    model.run_method("reset_states"); // Reset model states if applicable
#elif USE_ONNX
    // Zero the recurrent state tensor and the audio context prefix.
    std::memset(_state.data(), 0, _state.size() * sizeof(float));
    std::fill(_context.begin(), _context.end(), 0.0f);
#endif
}
185
186
std::vector<Interval> VadIterator::GetSpeechTimestamps() {
187
std::vector<Interval> speeches = DoVad();
188
if(!print_as_samples){
189
for (auto& speech : speeches) {
190
speech.start /= sample_rate;
191
speech.end /= sample_rate;
192
}
193
}
194
return speeches;
195
}
196
197
// Derives the sample-based engine parameters from the millisecond settings.
// Must be called after construction and before SpeechProbs(): the constructors
// do not call init_engine(), so window_size_samples etc. are unset until then.
void VadIterator::SetVariables(){
    // Initialize internal engine parameters
    init_engine(window_size_ms);
}
201
202
void VadIterator::init_engine(int window_size_ms) {
203
min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
204
speech_pad_samples = sample_rate * speech_pad_ms / 1000;
205
window_size_samples = sample_rate / 1000 * window_size_ms;
206
min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
207
#ifdef USE_ONNX
208
//for ONNX
209
context_samples=window_size_samples / 8;
210
_context.assign(context_samples, 0.0f);
211
212
effective_window_size = window_size_samples + context_samples; // 예: 512 + 64 = 576 samples
213
input_node_dims[0] = 1;
214
input_node_dims[1] = effective_window_size;
215
_state.resize(size_state);
216
sr.resize(1);
217
sr[0] = sample_rate;
218
#endif
219
}
220
221
// Converts the per-window probabilities in outputs_prob into speech segments.
// Uses hysteresis: a segment opens when prob >= threshold and can only close
// after prob drops below (threshold - 0.15) and stays there for at least
// min_silence_samples. Boundaries are in samples, padded by speech_pad_samples.
// On exit this calls reset_states(), which clears outputs_prob — so a second
// call returns nothing until SpeechProbs() is run again.
std::vector<Interval> VadIterator::DoVad() {
    std::vector<Interval> speeches;
    for (size_t i = 0; i < outputs_prob.size(); ++i) {
        float speech_prob = outputs_prob[i];
        // current_sample tracks the END of the current window.
        current_sample += window_size_samples;
        // Speech resumed before the silence hangover elapsed: cancel the
        // tentative end mark.
        if (speech_prob >= threshold && temp_end != 0) {
            temp_end = 0;
        }

        if (speech_prob >= threshold) {
            if (!triggered) {
                // Rising edge: open a segment, padded backwards (clamped to 0).
                triggered = true;
                Interval segment;
                segment.start = std::max(0, current_sample - speech_pad_samples - window_size_samples);
                speeches.push_back(segment);
            }
        }else {
            if (triggered) {
                // Only probabilities below the lower hysteresis bound count as
                // silence; the band [threshold-0.15, threshold) keeps the
                // segment open without marking an end.
                if (speech_prob < threshold - 0.15f) {
                    if (temp_end == 0) {
                        temp_end = current_sample;  // tentative segment end
                    }
                    // Enough continuous silence: finalize the open segment,
                    // padded forward from the tentative end.
                    if (current_sample - temp_end >= min_silence_samples) {
                        Interval& segment = speeches.back();
                        segment.end = temp_end + speech_pad_samples - window_size_samples;
                        temp_end = 0;
                        triggered = false;
                    }
                }
            }
        }


    }

    // Stream ended while a segment was still open: close it at the last sample.
    if (triggered) {
        std::cout<<"Finalizing active speech segment at stream end."<<std::endl;
        Interval& segment = speeches.back();
        segment.end = total_sample_size;
        triggered = false;
    }
    // Drop segments whose un-padded duration is below min_speech_samples.
    speeches.erase(std::remove_if(speeches.begin(), speeches.end(),
        [this](const Interval& speech) {
            return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
        }), speeches.end());

    reset_states();  // consumes outputs_prob and rewinds counters for the next run
    return speeches;
}
270
271
272
} // namespace silero
273
274
275