GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/helpers/datasets.py
import random
from pathlib import PurePath, Path
from typing import List, Callable, Dict, Optional

import torch
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset
from torchvision import datasets, transforms

from labml import lab
from labml import monit
from labml.configs import BaseConfigs
from labml.configs import aggregate, option
from labml.utils.download import download_file


def _mnist_dataset(is_train, transform):
    # Load (and download if necessary) MNIST into the labml data directory
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)


class MNISTConfigs(BaseConfigs):
    """
    Configurable MNIST dataset.

    Arguments:
        dataset_name (str): name of the dataset, ``MNIST``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.MNIST): training dataset
        valid_dataset (torchvision.datasets.MNIST): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """

    dataset_name: str = 'MNIST'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.MNIST
    valid_dataset: datasets.MNIST

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
    # Convert images to tensors and normalize with the MNIST mean and standard deviation
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
    # Training split
    return _mnist_dataset(True, c.dataset_transforms)


@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
    # Validation split
    return _mnist_dataset(False, c.dataset_transforms)


@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


# Group the MNIST calculators under ``dataset_name``, so picking ``'MNIST'``
# for ``dataset_name`` selects all of them at once.
aggregate(MNISTConfigs.dataset_name, 'MNIST',
          (MNISTConfigs.dataset_transforms, 'mnist_transforms'),
          (MNISTConfigs.train_dataset, 'mnist_train_dataset'),
          (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
          (MNISTConfigs.train_loader, 'mnist_train_loader'),
          (MNISTConfigs.valid_loader, 'mnist_valid_loader'))


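# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of driving ``MNISTConfigs`` through labml's experiment API;
# the experiment name and the batch-size override are arbitrary example values.
def _example_mnist_configs():
    from labml import experiment

    conf = MNISTConfigs()
    experiment.create(name='mnist_data_example')
    # Override a default; the remaining options are filled in by the
    # ``@option`` calculators above (transforms -> datasets -> loaders)
    experiment.configs(conf, {'train_batch_size': 128})
    with experiment.start():
        data, target = next(iter(conf.train_loader))
        print(data.shape, target.shape)  # expected: [128, 1, 28, 28] and [128]

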
def _cifar_dataset(is_train, transform):
    # Load (and download if necessary) CIFAR-10 into the labml data directory
    return datasets.CIFAR10(str(lab.get_data_path()),
                            train=is_train,
                            download=True,
                            transform=transform)


class CIFAR10Configs(BaseConfigs):
    """
    Configurable CIFAR-10 dataset.

    Arguments:
        dataset_name (str): name of the dataset, ``CIFAR10``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.CIFAR10): training dataset
        valid_dataset (torchvision.datasets.CIFAR10): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """
    dataset_name: str = 'CIFAR10'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.CIFAR10
    valid_dataset: datasets.CIFAR10

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
    # Convert images to tensors and normalize each channel to [-1, 1]
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
    # Training split
    return _cifar_dataset(True, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
    # Validation split
    return _cifar_dataset(False, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@CIFAR10Configs.calc(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


# Group the CIFAR-10 calculators under ``dataset_name``
CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
                         (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
                         (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
                         (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
                         (CIFAR10Configs.train_loader, 'cifar10_train_loader'),
                         (CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))


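# --- Illustrative usage (not part of the original file) ---
# A sketch of registering an alternative calculator for one of the configs;
# ``cifar10_train_dataset_augmented`` and its augmentation choices are example
# values, not part of this module's API. It could then be picked by name, e.g.
# ``experiment.configs(conf, {'train_dataset': 'cifar10_train_dataset_augmented'})``.
@option(CIFAR10Configs.train_dataset)
def cifar10_train_dataset_augmented(c: CIFAR10Configs):
    # Same CIFAR-10 training split, with random crops and horizontal flips
    # added before the usual normalization
    return datasets.CIFAR10(str(lab.get_data_path()),
                            train=True,
                            download=True,
                            transform=transforms.Compose([
                                transforms.RandomCrop(32, padding=4),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))

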
class TextDataset:
    """
    Text dataset that builds a token vocabulary and keeps the train/valid/test splits as strings.
    """
    itos: List[str]       # id -> token
    stoi: Dict[str, int]  # token -> id
    n_tokens: int
    train: str
    valid: str
    standard_tokens: List[str] = []

    @staticmethod
    def load(path: PurePath):
        # Read the whole text file
        with open(str(path), 'r') as f:
            return f.read()

    def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *,
                 n_tokens: Optional[int] = None,
                 stoi: Optional[Dict[str, int]] = None,
                 itos: Optional[List[str]] = None):
        self.test = test
        self.valid = valid
        self.train = train
        self.tokenizer = tokenizer
        self.path = path

        if n_tokens or stoi or itos:
            # Use the supplied vocabulary; all three must be given together
            assert stoi and itos and n_tokens
            self.n_tokens = n_tokens
            self.stoi = stoi
            self.itos = itos
        else:
            # Build the vocabulary from the training and validation text
            self.n_tokens = len(self.standard_tokens)
            self.stoi = {t: i for i, t in enumerate(self.standard_tokens)}

            with monit.section("Tokenize"):
                tokens = self.tokenizer(self.train) + self.tokenizer(self.valid)
                tokens = sorted(list(set(tokens)))

            for t in monit.iterate("Build vocabulary", tokens):
                self.stoi[t] = self.n_tokens
                self.n_tokens += 1

            self.itos = [''] * self.n_tokens
            for t, n in self.stoi.items():
                self.itos[n] = t

    def text_to_i(self, text: str) -> torch.Tensor:
        # Map a string to a tensor of token ids (tokens outside the vocabulary are dropped)
        tokens = self.tokenizer(text)
        return torch.tensor([self.stoi[s] for s in tokens if s in self.stoi], dtype=torch.long)

    def __repr__(self):
        return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}'


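# --- Illustrative usage (not part of the original file) ---
# A tiny sketch of building a character-level vocabulary with ``TextDataset``;
# the path and the strings are toy example values.
def _example_text_dataset():
    ds = TextDataset(Path('example.txt'), tokenizer=list,
                     train='hello world', valid='held out', test='')
    print(ds.n_tokens)            # vocabulary size, built from train + valid characters
    print(ds.text_to_i('hello'))  # tensor of token ids for the characters of 'hello'

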
class SequentialDataLoader(IterableDataset):
    """
    Iterates over contiguous mini-batches of shape ``[seq_len, batch_size]`` for language modeling.
    """
    def __init__(self, *, text: str, dataset: TextDataset,
                 batch_size: int, seq_len: int):
        self.seq_len = seq_len
        # Tokenize the text into a flat tensor of token ids
        data = dataset.text_to_i(text)
        # Trim to a multiple of the batch size and lay the tokens out so that
        # each column is a contiguous stream
        n_batch = data.shape[0] // batch_size
        data = data.narrow(0, 0, n_batch * batch_size)
        data = data.view(batch_size, -1).t().contiguous()
        self.data = data

    def __len__(self):
        return self.data.shape[0] // self.seq_len

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.data.shape[0] - 1:
            raise StopIteration()

        # Targets are the inputs shifted one position ahead
        seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx)
        i = self.idx + seq_len
        data = self.data[self.idx: i]
        target = self.data[self.idx + 1: i + 1]
        self.idx = i
        return data, target

    def __getitem__(self, idx):
        seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx)
        i = idx + seq_len
        data = self.data[idx: i]
        target = self.data[idx + 1: i + 1]
        return data, target


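# --- Illustrative usage (not part of the original file) ---
# A sketch of iterating ``SequentialDataLoader``. ``dataset`` is assumed to be a
# ``TextDataset`` (e.g. the ``TextFileDataset`` defined below); the batch size
# and sequence length are arbitrary example values.
def _example_sequential_loader(dataset: TextDataset):
    loader = SequentialDataLoader(text=dataset.train, dataset=dataset,
                                  batch_size=4, seq_len=32)
    for data, target in loader:
        # ``data`` and ``target`` are [seq_len, batch_size]; ``target`` is
        # ``data`` shifted one token ahead, for next-token prediction
        print(data.shape, target.shape)
        break

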
class SequentialUnBatchedDataset(Dataset):
    """
    Map-style dataset that returns single ``[seq_len]`` samples; batching is left to the ``DataLoader``.
    """
    def __init__(self, *, text: str, dataset: TextDataset,
                 seq_len: int,
                 is_random_offset: bool = True):
        self.is_random_offset = is_random_offset
        self.seq_len = seq_len
        self.data = dataset.text_to_i(text)

    def __len__(self):
        return (self.data.shape[0] - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        assert start + self.seq_len + 1 <= self.data.shape[0]
        if self.is_random_offset:
            # Shift the window by a random amount, bounded so it stays inside the data
            start += random.randint(0, min(self.seq_len - 1, self.data.shape[0] - (start + self.seq_len + 1)))

        end = start + self.seq_len
        data = self.data[start: end]
        target = self.data[start + 1: end + 1]
        return data, target


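# --- Illustrative usage (not part of the original file) ---
# A sketch of wrapping ``SequentialUnBatchedDataset`` in a standard PyTorch
# ``DataLoader``, which then handles batching and shuffling; the batch size and
# sequence length are arbitrary example values.
def _example_unbatched_loader(dataset: TextDataset):
    train_ds = SequentialUnBatchedDataset(text=dataset.train, dataset=dataset, seq_len=32)
    loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    for data, target in loader:
        # Each sample is [seq_len], so the collated batch is [batch_size, seq_len]
        print(data.shape, target.shape)
        break

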
class TextFileDataset(TextDataset):
    standard_tokens = []

    def __init__(self, path: PurePath, tokenizer: Callable, *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        path = Path(path)
        # Download the file if it is missing and a URL is given
        if not path.exists():
            if not url:
                raise FileNotFoundError(str(path))
            else:
                download_file(url, path)

        with monit.section("Load data"):
            text = self.load(path)
            if filter_subset:
                text = text[:filter_subset]
            # Use the first 90% of the text for training and the rest for validation
            split = int(len(text) * .9)
            train = text[:split]
            valid = text[split:]

        super().__init__(path, tokenizer, train, valid, '')


def _test_tiny_shakespeare():
    # Download Tiny Shakespeare and build a character-level dataset
    from labml import lab
    _ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
                        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


if __name__ == '__main__':
    _test_tiny_shakespeare()