Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/task/tts.py
694 views
1
from multiprocessing.pool import Pool
2
3
import matplotlib
4
5
from utils.pl_utils import data_loader
6
from utils.training_utils import RSQRTSchedule
7
from network.vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
8
from modules.fastspeech.pe import PitchExtractor
9
10
matplotlib.use('Agg')
11
import os
12
import numpy as np
13
from tqdm import tqdm
14
import torch.distributed as dist
15
16
from training.task.base_task import BaseTask
17
from utils.hparams import hparams
18
from utils.text_encoder import TokenTextEncoder
19
import json
20
from preprocessing.hubertinfer import Hubertencoder
21
import torch
22
import torch.optim
23
import torch.utils.data
24
import utils
25
26
27
28
class TtsTask(BaseTask):
    """Base TTS training task.

    Wires together the HuBERT-based content encoder, optimizer/scheduler
    construction, device-aware batched data loading, and test-time helpers
    (vocoder, optional pitch extractor, async result-saving pool).
    """

    def __init__(self, *args, **kwargs):
        self.vocoder = None
        # HuBERT content encoder used in place of a classical phoneme encoder.
        self.phone_encoder = Hubertencoder(hparams['hubert_path'])
        # Created lazily in test_start(); used to save inference results asynchronously.
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        # Reciprocal-square-root warmup schedule (Transformer-style LR decay).
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        # NOTE(review): this method was previously defined twice with identical
        # bodies in this class; the redundant second definition was removed.
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a DataLoader whose batches are pre-computed lists of indices.

        Args:
            dataset: dataset exposing ordered_indices(), num_tokens, num_workers
                and a collater (project convention — confirm against dataset class).
            shuffle: shuffle the batch order (not the contents of each batch).
            max_tokens: per-device token budget; scaled by the device count.
            max_sentences: per-device sentence budget; scaled by the device count.
                Required (non-None) when batch_by_size is False.
            required_batch_size_multiple: pad batches to a multiple of this;
                -1 means "use the device count".
            endless: repeat the batch list 1000x to emulate an endless sampler.
            batch_by_size: use utils.batch_by_size for token-bucketed batches
                instead of fixed-size index chunks.

        Returns:
            torch.utils.data.DataLoader driven by an explicit batch_sampler.
        """
        devices_cnt = torch.cuda.device_count()
        if devices_cnt == 0:
            devices_cnt = 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            np.random.shuffle(batches)
            return batches

        # Budgets are given per device; scale to the global budget.
        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt
        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            # Fixed-size chunks of indices; max_sentences must be set here.
            batch_sampler = [indices[i:i + max_sentences]
                             for i in range(0, len(indices), max_sentences)]

        if shuffle:
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                # Re-shuffle on every repetition so epochs differ.
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]
        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            # Drop batches not evenly divisible across replicas, then take
            # this rank's interleaved shard of each remaining batch.
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def test_start(self):
        """Set up test-time helpers: saving pool, vocoder, optional pitch extractor."""
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Drain all pending result-saving futures and shut the pool down."""
        self.saving_result_pool.close()
        [f.get() for f in tqdm(self.saving_results_futures)]
        self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        # target: B x T x mel
        # Weight 1.0 for every frame that is not all-zero padding, broadcast
        # back over the mel dimension.
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
129
130
if __name__ == '__main__':
    # Script entry point: delegate to the base task runner to launch training.
    TtsTask.start()
133