#ifndef SILERO_H
#define SILERO_H
#include <string>
#include <vector>
#include <iostream>
#include <fstream>
#include <chrono>
#include <algorithm>
#include <cstring>
#ifdef USE_TORCH
#include <torch/torch.h>
#include <torch/script.h>
#elif USE_ONNX
#include "onnxruntime_cxx_api.h"
#endif
namespace silero {
/// A start/end pair plus a sub-segment counter.
///
/// Every field has a default member initializer, so a plain `Interval iv;`
/// starts out zeroed instead of holding indeterminate values. The struct is
/// still an aggregate under C++14 and later, so `Interval{start, end, n}`
/// keeps working.
///
/// NOTE(review): the unit of `start`/`end` (seconds vs. samples) is not fixed
/// by this declaration — confirm against the code that fills it in.
struct Interval {
    float start = 0.0f;      ///< Beginning of the interval.
    float end = 0.0f;        ///< End of the interval.
    int numberOfSubseg = 0;  ///< Counter of sub-segments associated with this interval.

    /// Resets all three fields to zero so the object can be reused.
    void initialize() {
        start = 0;
        end = 0;
        numberOfSubseg = 0;
    }
};
// Voice-activity-detection driver (going by the "Vad" name and the
// SpeechProbs / GetSpeechTimestamps API). The inference backend is selected at
// compile time: TorchScript members are declared when USE_TORCH is defined,
// ONNX Runtime members when USE_ONNX evaluates to non-zero. This header only
// declares the members; none of the member functions are defined here.
class VadIterator {
public:
// Constructor. Every parameter except model_path has the default shown.
// The parameter names mirror the data members declared further down.
//   model_path              - path string for the model (how it is loaded is
//                             backend-specific and not visible in this header)
//   threshold               - default 0.5 (a double literal converted to float)
//   sample_rate             - default 16000
//   window_size_ms          - default 32
//   speech_pad_ms           - default 30
//   min_silence_duration_ms - default 100
//   min_speech_duration_ms  - default 250
//   max_duration_merge_ms   - default 300
//   print_as_samples        - default false
// NOTE(review): because all trailing parameters are defaulted and the
// constructor is not `explicit`, a std::string converts implicitly to
// VadIterator.
VadIterator(const std::string &model_path,
float threshold = 0.5,
int sample_rate = 16000,
int window_size_ms = 32,
int speech_pad_ms = 30,
int min_silence_duration_ms = 100,
int min_speech_duration_ms = 250,
int max_duration_merge_ms = 300,
bool print_as_samples = false);
// User-declared destructor; its definition is not in this header.
~VadIterator();
// Takes the waveform by non-const reference, so the implementation is allowed
// to modify the caller's vector; whether it does is not visible here.
// NOTE(review): presumably fills outputs_prob with per-window probabilities —
// confirm in the implementation.
void SpeechProbs(std::vector<float>& input_wav);
// Returns the detected intervals by value.
std::vector<Interval> GetSpeechTimestamps();
// No parameters and no return value.
// NOTE(review): presumably recomputes the derived *_samples members from the
// public settings below — confirm in the implementation.
void SetVariables();
// Public, mutable settings whose names match constructor parameters.
// NOTE(review): speech_pad_ms and min_silence_duration_ms are constructor
// parameters as well, but their members are declared private below.
float threshold;
int sample_rate;
int window_size_ms;
int min_speech_duration_ms;
int max_duration_merge_ms;
bool print_as_samples;
private:
// ---- TorchScript backend (only when USE_TORCH is defined) ----
#ifdef USE_TORCH
torch::jit::script::Module model;
void init_torch_model(const std::string& model_path);
// ---- ONNX Runtime backend ----
// `#elif USE_ONNX` tests the macro's value, not merely whether it is defined:
// an undefined macro counts as 0, and a macro defined to nothing is a
// preprocessor error here.
#elif USE_ONNX
Ort::Env env;
Ort::SessionOptions session_options;
// NOTE(review): std::shared_ptr is used here, but this header does not include
// <memory> directly; it relies on another header to provide it.
std::shared_ptr<Ort::Session> session;
Ort::AllocatorWithDefaultOptions allocator;
Ort::MemoryInfo memory_info;
void init_onnx_model(const std::string& model_path);
// predict() is declared only in this branch; the Torch branch has no
// equivalent declaration.
float predict(const std::vector<float>& data_chunk);
int context_samples;
std::vector<float> _context;
int effective_window_size;
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_node_names;
std::vector<float> input;
unsigned int size_state;
std::vector<float> _state;
// NOTE(review): int64_t is used below without a direct <cstdint> include.
std::vector<int64_t> sr;
// Dimension arrays for the input, state and sample-rate nodes (rank 2, 3 and 1
// respectively, per the array sizes). The two const arrays have no in-class
// initializer, so the constructor must set them in its member-initializer
// list; they also make the implicit copy assignment operator deleted.
int64_t input_node_dims[2];
const int64_t state_node_dims[3];
const int64_t sr_node_dims[1];
std::vector<Ort::Value> ort_outputs;
std::vector<const char*> output_node_names;
#endif
// ---- Members common to both backends ----
std::vector<float> outputs_prob;
// Sample-count members; none has an in-class initializer.
// NOTE(review): presumably derived from the *_ms settings and sample_rate —
// confirm where they are assigned.
int min_silence_samples;
int min_speech_samples;
int speech_pad_samples;
int window_size_samples;
int duration_merge_samples;
// Counters that start at 0 via in-class initializers.
int current_sample = 0;
int total_sample_size = 0;
// Private counterparts of the constructor parameters with the same names.
int min_silence_duration_ms;
int speech_pad_ms;
// Detection state with in-class initial values.
bool triggered = false;
int temp_end = 0;
int global_end = 0;
int erase_tail_count = 0;
// Private helpers; their definitions are not in this header.
void init_engine(int window_size_ms);
void reset_states();
std::vector<Interval> DoVad();
};
}
#endif