Path: blob/master/examples/cpp_libtorch/silero_torch.cc
//Author : Nathan Lee
//Created On : 2024-11-18
//Description : silero 5.1 system for torch-script(c++).
//Version : 1.0

#include "silero_torch.h"

namespace silero {

VadIterator::VadIterator(const std::string &model_path, float threshold, int sample_rate, int window_size_ms,
                         int speech_pad_ms, int min_silence_duration_ms, int min_speech_duration_ms,
                         int max_duration_merge_ms, bool print_as_samples)
    : sample_rate(sample_rate), threshold(threshold), window_size_ms(window_size_ms), speech_pad_ms(speech_pad_ms),
      min_silence_duration_ms(min_silence_duration_ms), min_speech_duration_ms(min_speech_duration_ms),
      max_duration_merge_ms(max_duration_merge_ms), print_as_samples(print_as_samples)
{
    init_torch_model(model_path);
    //init_engine(window_size_ms);
}

VadIterator::~VadIterator() {
}

void VadIterator::SpeechProbs(std::vector<float>& input_wav) {
    // The sample rate must match the model's expected sample rate.
    // Process the waveform in chunks of window_size_samples (e.g. 512 samples).
    int num_samples = static_cast<int>(input_wav.size());
    int num_chunks = num_samples / window_size_samples;
    int remainder_samples = num_samples % window_size_samples;

    total_sample_size += num_samples;

    std::vector<torch::Tensor> chunks;

    for (int i = 0; i < num_chunks; i++) {
        float* chunk_start = input_wav.data() + i * window_size_samples;
        torch::Tensor chunk = torch::from_blob(chunk_start, {1, window_size_samples}, torch::kFloat32);
        chunks.push_back(chunk);
    }

    // Handle the trailing partial chunk outside the loop, so that inputs shorter
    // than one window are not silently dropped.
    if (remainder_samples > 0) {
        float* chunk_start_remainder = input_wav.data() + num_chunks * window_size_samples;
        torch::Tensor remainder_chunk = torch::from_blob(chunk_start_remainder, {1, remainder_samples}, torch::kFloat32);
        // Zero-pad the remainder so it matches window_size_samples.
        torch::Tensor padded_chunk = torch::cat({remainder_chunk,
                                                 torch::zeros({1, window_size_samples - remainder_samples}, torch::kFloat32)}, 1);
        chunks.push_back(padded_chunk);
    }

    if (!chunks.empty()) {

#ifdef USE_BATCH
        // Stack all chunks into a single (num_chunks, window_size_samples) tensor.
        torch::Tensor batched_chunks = torch::stack(chunks).squeeze(1);
#ifdef USE_GPU
        batched_chunks = batched_chunks.to(at::kCUDA); // Move the entire batch to GPU once
#endif
        // Prepare input for the model.
        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(batched_chunks); // Batch of chunks
        inputs.push_back(sample_rate);    // Assuming sample_rate is a valid input for the model

        // Run inference on the whole batch at once.
        torch::NoGradGuard no_grad;
        torch::Tensor output = model.forward(inputs).toTensor();
#ifdef USE_GPU
        output = output.to(at::kCPU); // Move the output back to CPU once
#endif
        // Collect output probabilities.
        for (size_t i = 0; i < chunks.size(); i++) {
            outputs_prob.push_back(output[i].item<float>());
        }
#else
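        // Per-chunk path: the model keeps internal state between forward calls
        // (cleared via reset_states below), so feeding chunks one at a time lets
        // that state carry across chunk boundaries, unlike the batched path above.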
        std::vector<torch::Tensor> outputs;
        torch::Tensor batched_chunks = torch::stack(chunks);
#ifdef USE_GPU
        batched_chunks = batched_chunks.to(at::kCUDA);
#endif
        for (size_t i = 0; i < chunks.size(); i++) {
            torch::NoGradGuard no_grad;
            std::vector<torch::jit::IValue> inputs;
            inputs.push_back(batched_chunks[i]);
            inputs.push_back(sample_rate);

            torch::Tensor output = model.forward(inputs).toTensor();
            outputs.push_back(output);
        }
        torch::Tensor all_outputs = torch::stack(outputs);
#ifdef USE_GPU
        all_outputs = all_outputs.to(at::kCPU);
#endif
        for (size_t i = 0; i < chunks.size(); i++) {
            outputs_prob.push_back(all_outputs[i].item<float>());
        }
#endif

    }
}

std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
    std::vector<SpeechSegment> speeches = DoVad();

#ifdef USE_BATCH
    // With batch inference the per-chunk probabilities can be distorted, so the
    // raw segments should be arranged with 'mergeSpeeches' to obtain more
    // reasonable timestamps.
    duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
    std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
    if (!print_as_samples) {
        for (auto& speech : speeches_merge) { // samples to seconds
            speech.start /= sample_rate;
            speech.end /= sample_rate;
        }
    }
    return speeches_merge;
#else
    if (!print_as_samples) {
        for (auto& speech : speeches) { // samples to seconds
            speech.start /= sample_rate;
            speech.end /= sample_rate;
        }
    }
    return speeches;
#endif
}

void VadIterator::SetVariables() {
    init_engine(window_size_ms);
}

void VadIterator::init_engine(int window_size_ms) {
    min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
    speech_pad_samples = sample_rate * speech_pad_ms / 1000;
    window_size_samples = sample_rate / 1000 * window_size_ms;
    min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
}
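
// Worked example for the sizes above, with illustrative values (assumed, not
// fixed by this file): at sample_rate = 16000 and window_size_ms = 32,
// window_size_samples = 16000 / 1000 * 32 = 512 samples per chunk; with
// min_speech_duration_ms = 250, min_speech_samples = 16000 * 250 / 1000 = 4000
// samples, i.e. 0.25 seconds.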

void VadIterator::init_torch_model(const std::string& model_path) {
    at::set_num_threads(1);
    model = torch::jit::load(model_path);

#ifdef USE_GPU
    if (!torch::cuda::is_available()) {
        std::cout << "CUDA is not available! Please check your GPU settings" << std::endl;
        throw std::runtime_error("CUDA is not available!");
    } else {
        std::cout << "CUDA available! Running on 0th GPU" << std::endl;
        model.to(at::Device(at::kCUDA, 0)); // select the 0th device
    }
#endif

    model.eval();
    torch::NoGradGuard no_grad;
    std::cout << "Model loaded successfully" << std::endl;
}

void VadIterator::reset_states() {
    triggered = false;
    current_sample = 0;
    temp_end = 0;
    outputs_prob.clear();
    model.run_method("reset_states");
    total_sample_size = 0;
}

std::vector<SpeechSegment> VadIterator::DoVad() {
    std::vector<SpeechSegment> speeches;

    for (size_t i = 0; i < outputs_prob.size(); ++i) {
        float speech_prob = outputs_prob[i];
        current_sample += window_size_samples;

        if (speech_prob >= threshold && temp_end != 0) {
            temp_end = 0;
        }

        if (speech_prob >= threshold && !triggered) {
            triggered = true;
            SpeechSegment segment;
            segment.start = std::max(0, current_sample - speech_pad_samples - window_size_samples);
            speeches.push_back(segment);
            continue;
        }

        if (speech_prob < threshold - 0.15f && triggered) {
            if (temp_end == 0) {
                temp_end = current_sample;
            }

            if (current_sample - temp_end < min_silence_samples) {
                continue;
            } else {
                SpeechSegment& segment = speeches.back();
                segment.end = temp_end + speech_pad_samples - window_size_samples;
                temp_end = 0;
                triggered = false;
            }
        }
    }

    // If speech is still active after the last frame, close the final segment at
    // the end of the input. Note: if only the very last frame crosses the
    // threshold, the segment is opened and closed at the same sample
    // (start == end); the minimum-length filter below removes such segments.
    if (triggered) {
        std::cout << "speech still active at the last frame; closing final segment" << std::endl;
        SpeechSegment& segment = speeches.back();
        segment.end = total_sample_size; // use the end of the input as the segment end
        triggered = false; // reset the VAD state
    }

    // Drop segments shorter than min_speech_samples (e.g. 4000 samples = 0.25 sec
    // at 16 kHz). Key point: the length is measured after 'speech_pad_samples'
    // has been applied to both the start and the end.
    speeches.erase(
        std::remove_if(
            speeches.begin(),
            speeches.end(),
            [this](const SpeechSegment& speech) {
                return ((speech.end - this->speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
            }
        ),
        speeches.end()
    );

    reset_states();
    return speeches;
}

std::vector<SpeechSegment> VadIterator::mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples) {
    std::vector<SpeechSegment> mergedSpeeches;

    if (speeches.empty()) {
        return mergedSpeeches; // return an empty result
    }

    // Initialize with the first segment.
    SpeechSegment currentSegment = speeches[0];

    // Start from i = 1: the first segment already seeds currentSegment.
    for (size_t i = 1; i < speeches.size(); ++i) {
        if (speeches[i].start - currentSegment.end < duration_merge_samples) {
            // Gap below the merge threshold: extend the current segment.
            currentSegment.end = speeches[i].end;
        } else {
            // Gap at or above the merge threshold: store the current segment and start a new one.
            mergedSpeeches.push_back(currentSegment);
            currentSegment = speeches[i];
        }
    }

    // Add the last segment.
    mergedSpeeches.push_back(currentSegment);

    return mergedSpeeches;
}

} // namespace silero
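
// ---------------------------------------------------------------------------
// Usage sketch (hypothetical; not part of the original file). It shows one way
// this class could be driven, assuming "silero_torch.h" declares VadIterator as
// defined above. The model filename, the parameter values, and the
// LoadWavAsFloat helper are illustrative assumptions, not project APIs.
//
//   #include <iostream>
//   #include <vector>
//   #include "silero_torch.h"
//
//   int main() {
//       silero::VadIterator vad("silero_vad.jit", // hypothetical model path
//                               0.5f,   // threshold
//                               16000,  // sample_rate
//                               32,     // window_size_ms
//                               30,     // speech_pad_ms
//                               100,    // min_silence_duration_ms
//                               250,    // min_speech_duration_ms
//                               300,    // max_duration_merge_ms
//                               false); // print_as_samples: false -> seconds
//       vad.SetVariables(); // required before SpeechProbs: sets window_size_samples
//
//       std::vector<float> wav = LoadWavAsFloat("sample.wav"); // hypothetical helper
//       vad.SpeechProbs(wav); // run the model and collect per-chunk probabilities
//       for (const auto& s : vad.GetSpeechTimestamps()) {
//           std::cout << "speech: " << s.start << " - " << s.end << std::endl;
//       }
//       return 0;
//   }
// ---------------------------------------------------------------------------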