Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/preprocessing/hubertinfer.py
694 views
1
import os.path
2
from io import BytesIO
3
from pathlib import Path
4
5
import numpy as np
6
import torch
7
8
from network.hubert.hubert_model import hubert_soft, get_units
9
from network.hubert.vec_model import load_model, get_vec_units
10
from utils.hparams import hparams
11
12
13
class Hubertencoder():
    """Load a content-encoder model and turn audio into unit features.

    Which model is used depends on ``hparams['use_vec']``: a truthy value
    selects the "vec" model (``load_model`` / ``get_vec_units``), anything
    else selects HuBERT-soft (``hubert_soft`` / ``get_units``).
    """

    def __init__(self, pt_path='checkpoints/hubert/hubert_soft.pt'):
        """Load the encoder checkpoint and pick the inference device.

        Side effect: if ``'use_vec'`` is missing from the shared ``hparams``
        mapping it is inserted with the value ``False``.

        Args:
            pt_path: Only the *parent directory* of this path is used. In
                HuBERT-soft mode that directory is searched recursively for
                ``*.pt`` files and the first hit is loaded. Ignored in
                ``use_vec`` mode, where a fixed checkpoint path is used.

        Raises:
            IndexError: In HuBERT-soft mode, when no ``*.pt`` file is found
                under the checkpoint directory.
        """
        # encode() reads this key unconditionally, so guarantee it exists.
        if 'use_vec' not in hparams.keys():
            hparams['use_vec'] = False

        if hparams['use_vec']:
            pt_path = "checkpoints/vec/checkpoint_best_legacy_500.pt"
            # Fall back to CPU when CUDA is missing, mirroring the HuBERT
            # branch below, instead of naming a device that does not exist.
            # NOTE(review): assumes load_model() places the model on the
            # matching device — confirm against network.hubert.vec_model.
            self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.hbt_model = load_model(pt_path)
        else:
            ckpt_dir = Path(pt_path).parent
            # Take the first match lazily; same element as list(...)[0]
            # without materialising every checkpoint path.
            ckpt_file = next(ckpt_dir.rglob('*.pt'), None)
            if ckpt_file is None:
                # IndexError is kept (the original failed with it) so any
                # existing handler still works; only the message is new.
                raise IndexError(
                    f"no '*.pt' HuBERT checkpoint found under '{ckpt_dir}'"
                )
            pt_path = ckpt_file

            # GPU use is opt-out: default to True unless hparams says otherwise.
            if 'hubert_gpu' in hparams.keys():
                self.use_gpu = hparams['hubert_gpu']
            else:
                self.use_gpu = True
            self.dev = torch.device("cuda" if self.use_gpu and torch.cuda.is_available() else "cpu")
            self.hbt_model = hubert_soft(str(pt_path)).to(self.dev)

    def encode(self, wav_path):
        """Return the unit features for one audio input.

        Args:
            wav_path: Either a filesystem path to the audio, or a ``BytesIO``
                holding it. A ``BytesIO`` is rewound to offset 0 before use.

        Returns:
            A NumPy array. For a path input, if a sibling file with the same
            name and a ``.npy`` suffix exists, its contents are loaded and
            returned without running the model. Otherwise the first element
            of the model output is returned. The original code annotates the
            result as ``[T,256]`` — TODO confirm this also holds for the
            ``use_vec`` model.
        """
        if isinstance(wav_path, BytesIO):
            # In-memory audio has no on-disk cache to consult.
            npy_path = None
            wav_path.seek(0)
        else:
            npy_path = Path(wav_path).with_suffix('.npy')

        if npy_path is not None and os.path.exists(npy_path):
            return np.load(str(npy_path))

        extract_units = get_vec_units if hparams['use_vec'] else get_units
        units = extract_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
        return units  # [T,256]
43
44