Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/nsf_hifigan/utils.py
694 views
1
import glob
2
import os
3
import matplotlib
4
import torch
5
from torch.nn.utils import weight_norm
6
matplotlib.use("Agg")
7
import matplotlib.pylab as plt
8
9
10
def plot_spectrogram(spectrogram):
    """Draw a spectrogram as an image with a colorbar and return the figure.

    The input is rendered with ``imshow`` (origin at the bottom, automatic
    aspect ratio, no interpolation) on a 10x2 inch figure. The canvas is
    drawn and pyplot's current figure is closed before the figure object
    is handed back to the caller.
    """
    figure, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation='none',
    )
    plt.colorbar(image, ax=axis)

    # Render the canvas first, then close pyplot's current figure.
    figure.canvas.draw()
    plt.close()

    return figure
20
21
22
def init_weights(m, mean=0.0, std=0.01):
    """Fill the weights of "Conv"-named modules with normal samples, in place.

    Only objects whose class name contains the substring ``"Conv"`` are
    touched; for those, ``m.weight.data.normal_(mean, std)`` is called.
    Every other object is left unchanged. Returns ``None``.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
26
27
28
def apply_weight_norm(m):
    """Call ``weight_norm`` on modules whose class name contains "Conv".

    Objects with any other class name are ignored. The value returned by
    ``weight_norm`` is discarded, so this function always returns ``None``.
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        weight_norm(m)
32
33
34
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int.

    The division is a true (float) division and the result is converted
    with ``int``, i.e. truncated toward zero.
    """
    dilated_extent = kernel_size * dilation - dilation
    return int(dilated_extent / 2)
36
37
38
def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filepath: Path of the checkpoint file to read.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialised by ``torch.load``.

    Raises:
        AssertionError: If ``filepath`` is not an existing regular file.
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # is not stripped under ``python -O``. AssertionError is kept so callers
    # observe the same exception type as before.
    if not os.path.isfile(filepath):
        raise AssertionError(
            "Checkpoint file not found: '{}'".format(filepath)
        )
    print("Loading '{}'".format(filepath))
    # NOTE(review): torch.load unpickles the file contents; only load
    # checkpoints that come from a trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
44
45
46
def save_checkpoint(filepath, obj):
    """Write ``obj`` to ``filepath`` with ``torch.save``.

    A message is printed before the save starts and another once it has
    finished. Returns ``None``.
    """
    start_message = "Saving checkpoint to {}".format(filepath)
    print(start_message)
    torch.save(obj, filepath)
    print("Complete.")
50
51
52
def del_old_checkpoints(cp_dir, prefix, n_models=2):
    """Delete all but the last ``n_models`` checkpoints found in ``cp_dir``.

    Checkpoints are the paths matching ``prefix`` followed by exactly eight
    characters. They are ordered by sorting the path strings, so "latest"
    means last in that order (presumably the suffix is a zero-padded
    iteration count -- verify against the code that writes checkpoints).

    Args:
        cp_dir: Directory to scan.
        prefix: File-name prefix of the checkpoints to consider.
        n_models: Number of latest checkpoints to keep. ``0`` removes every
            matching checkpoint.

    Raises:
        ValueError: If ``n_models`` is negative.
    """
    if n_models < 0:
        raise ValueError(
            "n_models must be non-negative, got {}".format(n_models)
        )
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = sorted(glob.glob(pattern))  # checkpoint paths, sorted by iter
    # Count stale entries from the front instead of slicing with
    # ``[:-n_models]``: that slice is empty when n_models == 0, which would
    # leave every checkpoint in place.
    n_stale = max(len(cp_list) - n_models, 0)
    for cp in cp_list[:n_stale]:  # the oldest models other than latest n_models
        # Empty the file contents before unlinking; on Colab the delete only
        # moves the file to trash (presumably truncation keeps the trashed
        # copy from holding space).
        with open(cp, 'w'):
            pass
        os.unlink(cp)
60
61
62
def scan_checkpoint(cp_dir, prefix):
    """Return the last matching checkpoint path in sorted order, or ``None``.

    A checkpoint is any path in ``cp_dir`` whose name is ``prefix`` followed
    by exactly eight characters. When nothing matches, ``None`` is returned.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '????????'))
    # The greatest path string is the same element a full sort would put last.
    return max(candidates) if candidates else None
68