Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/training/dataset/base_dataset.py
694 views
1
import torch
2
from utils.hparams import hparams
3
import numpy as np
4
import os
5
6
class BaseDataset(torch.utils.data.Dataset):
    """Base class for datasets.

    Provided by this class:

    1. *ordered_indices*:
        if ``self.shuffle`` is true, return a random permutation of the
        indices, additionally stable-sorted by length when
        ``self.sort_by_len`` is true; otherwise return the indices in their
        natural order (``sort_by_len`` is not consulted in that case).
    2. *size*:
        an item's length taken from ``self._sizes``, clipped to
        ``hparams['max_frames']``.
    3. *num_tokens*:
        the same value as *size* (it simply delegates to it).

    Subclasses should define:

    1. *collater*:
        take the longest data, pad other data to the same length;
    2. *__getitem__*:
        the index function.

    ``self.sizes`` starts out as ``None``; ``__len__``, ``size`` and
    ``ordered_indices`` all read it through ``_sizes``, so it has to be
    populated (or ``_sizes`` overridden) before any of them is used.
    """

    def __init__(self, shuffle):
        """Store the global ``hparams`` and the ordering options.

        Args:
            shuffle: whether ``ordered_indices`` should randomise the order.
        """
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    @property
    def _sizes(self):
        """Per-item lengths; by default just ``self.sizes``."""
        return self.sizes

    def __getitem__(self, index):
        """Return the item at *index*. Must be implemented by subclasses."""
        raise NotImplementedError

    def collater(self, samples):
        """Merge *samples* into a batch. Must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of items, i.e. ``len(self._sizes)``."""
        return len(self._sizes)

    def num_tokens(self, index):
        """Return the size of item *index*; identical to ``size(index)``."""
        return self.size(index)

    def size(self, index):
        """Return the length of item *index*, capped at ``hparams['max_frames']``."""
        size = min(self._sizes[index], hparams['max_frames'])
        return size

    def ordered_indices(self):
        """Return an ordered array of indices. Batches will be constructed
        based on this order.

        Without shuffling this is simply ``0 .. len(self) - 1``.
        """
        if self.shuffle:
            indices = np.random.permutation(len(self))
            if self.sort_by_len:
                # Shuffle first, then sort with a *stable* algorithm, so that
                # items of equal length keep the order given by the random
                # permutation (i.e. they stay shuffled among themselves).
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices

    @property
    def num_workers(self):
        """Number of data-loading workers.

        Read from the ``NUM_WORKERS`` environment variable when it is set,
        falling back to ``hparams['ds_workers']``; the result is cast to int.
        """
        return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
67
68