// Source: GitHub repository snakers4/silero-vad
// Path: blob/master/examples/c++/silero.h
#ifndef SILERO_H
#define SILERO_H

// silero.h
// Author : NathanJHLee
// Created On : 2025-11-10
// Description : Silero 6.2 system for ONNX Runtime (C++) and TorchScript (C++)
// Version : 1.3

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#ifdef USE_TORCH
#include <torch/torch.h>
#include <torch/script.h>
#elif USE_ONNX
#include "onnxruntime_cxx_api.h"
#endif

namespace silero {

/// One detected segment, described by a start position, an end position and
/// a sub-segment count.
///
/// NOTE(review): the unit of `start`/`end` is not fixed by this header;
/// presumably seconds, or samples when VadIterator::print_as_samples is set.
/// Confirm against the implementation file.
struct Interval {
    // Default member initializers keep a default-constructed Interval from
    // carrying indeterminate values before initialize() is called.
    float start = 0.0f;
    float end = 0.0f;
    int numberOfSubseg = 0;

    /// Resets all three fields to zero.
    void initialize() {
        start = 0;
        end = 0;
        numberOfSubseg = 0;
    }
};

/// Voice-activity detector whose inference backend is selected at compile
/// time: TorchScript when USE_TORCH is defined, otherwise ONNX Runtime when
/// USE_ONNX is set. All member functions are defined outside this header.
class VadIterator {
public:
    /// @param model_path               Path of the model file to load.
    /// @param threshold                Stored in the public `threshold` field (default 0.5).
    /// @param sample_rate              Sample rate in Hz (default 16000).
    /// @param window_size_ms           Analysis window length in milliseconds (default 32).
    /// @param speech_pad_ms            Speech padding in milliseconds (default 30).
    /// @param min_silence_duration_ms  Minimum silence duration in milliseconds (default 100).
    /// @param min_speech_duration_ms   Minimum speech duration in milliseconds (default 250).
    /// @param max_duration_merge_ms    Maximum merge duration in milliseconds (default 300).
    /// @param print_as_samples         NOTE(review): presumably selects sample-based
    ///                                 instead of time-based output; confirm in the implementation.
    ///
    /// NOTE(review): this constructor is callable with a single argument and is
    /// not `explicit`, so a std::string converts implicitly to a VadIterator.
    /// Consider marking it `explicit` if no caller depends on that conversion.
    VadIterator(const std::string &model_path,
                float threshold = 0.5,
                int sample_rate = 16000,
                int window_size_ms = 32,
                int speech_pad_ms = 30,
                int min_silence_duration_ms = 100,
                int min_speech_duration_ms = 250,
                int max_duration_merge_ms = 300,
                bool print_as_samples = false);
    ~VadIterator();

    // Batch (non-streaming) interface (for backward compatibility)

    /// Processes a whole input waveform.
    /// NOTE(review): presumably fills `outputs_prob`; confirm in the implementation.
    void SpeechProbs(std::vector<float>& input_wav);
    /// Returns the detected segments as a list of Interval.
    std::vector<Interval> GetSpeechTimestamps();
    /// NOTE(review): presumably recomputes the derived sample counts after the
    /// public parameters below were modified; confirm in the implementation.
    void SetVariables();

    // Public parameters (can be modified by user)
    float threshold;
    int sample_rate;
    int window_size_ms;
    int min_speech_duration_ms;
    int max_duration_merge_ms;
    bool print_as_samples;

private:
#ifdef USE_TORCH
    torch::jit::script::Module model;
    void init_torch_model(const std::string& model_path);
#elif USE_ONNX
    Ort::Env env;                                // environment object
    Ort::SessionOptions session_options;         // session options
    std::shared_ptr<Ort::Session> session;       // ONNX session
    Ort::AllocatorWithDefaultOptions allocator;  // default allocator
    Ort::MemoryInfo memory_info;                 // memory info (CPU)

    void init_onnx_model(const std::string& model_path);
    float predict(const std::vector<float>& data_chunk);

    int context_samples;         // e.g. 64 samples
    std::vector<float> _context; // initial values are all 0
    int effective_window_size;

    // ONNX input/output buffers and node names
    std::vector<Ort::Value> ort_inputs;
    std::vector<const char*> input_node_names;
    std::vector<float> input;
    unsigned int size_state;     // fixed value: 2*1*128
    std::vector<float> _state;
    std::vector<int64_t> sr;
    int64_t input_node_dims[2];  // [1, effective_window_size]
    // The two const arrays below have no in-class initializer, so they must be
    // set in the constructor's member-initializer list; being const, they also
    // make the class non-copy-assignable.
    const int64_t state_node_dims[3]; // [ 2, 1, 128 ]
    const int64_t sr_node_dims[1];    // [ 1 ]
    std::vector<Ort::Value> ort_outputs;
    std::vector<const char*> output_node_names; // default: [ "output", "stateN" ]
#endif
    std::vector<float> outputs_prob; // used in batch mode
    int min_silence_samples;
    int min_speech_samples;
    int speech_pad_samples;
    int window_size_samples;
    int duration_merge_samples;
    int current_sample = 0;
    int total_sample_size = 0;
    int min_silence_duration_ms;
    int speech_pad_ms;
    bool triggered = false;
    int temp_end = 0;
    int global_end = 0;
    int erase_tail_count = 0;

    void init_engine(int window_size_ms);
    void reset_states();
    std::vector<Interval> DoVad();
};

} // namespace silero
#endif // SILERO_H