GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/experiments/nlp_classification.py
"""
2
---
3
title: NLP classification trainer
4
summary: >
5
This is a reusable trainer for classification tasks
6
---
7
8
# NLP model trainer for classification
9
"""
10
11
from collections import Counter
12
from typing import Callable
13
14
import torchtext
15
import torchtext.vocab
16
from torchtext.vocab import Vocab
17
18
import torch
19
from labml import lab, tracker, monit
20
from labml.configs import option
21
from labml_nn.helpers.device import DeviceConfigs
22
from labml_nn.helpers.metrics import Accuracy
23
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
24
from labml_nn.optimizers.configs import OptimizerConfigs
25
from torch import nn
26
from torch.utils.data import DataLoader
27
28
29
class NLPClassificationConfigs(TrainValidConfigs):
30
"""
31
<a id="NLPClassificationConfigs"></a>
32
33
## Trainer configurations
34
35
This has the basic configurations for NLP classification task training.
36
All the properties are configurable.
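
    String-valued defaults such as `'ag_news'` and `'character'` name config
    options (functions decorated with `@option`) defined later in this file;
    the named option computes the value. Experiments can override any of these
    properties through the configurations dictionary, for example
    `{'tokenizer': 'basic_english', 'batch_size': 32}`.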
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: nn.Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but they can still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but they can still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
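
        `batch` is the `(data, labels)` pair produced by the `CollateFunc` below:
        `data` has shape `[seq_len, batch_size]` and `labels` has shape `[batch_size]`.
        The model is assumed to return a tuple whose first element is the class logits
        with shape `[batch_size, n_classes]` (a sketch of the expected interface,
        not part of this file):

        ```
        def forward(self, x: torch.Tensor):
            # x: [seq_len, batch_size]
            ...
            return logits, state  # logits: [batch_size, n_classes]
        ```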
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Get model outputs.
        # The model returns a tuple; the extra elements are for states when using RNNs.
        # This is not implemented yet. 😜
        output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    """
    ### Default [optimizer configurations](../optimizers/configs.html)
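
    This creates an `OptimizerConfigs` object, so optimizer settings can be
    overridden from the experiment's configurations dictionary with the
    `optimizer.` prefix; for instance (an illustrative sketch, assuming the
    option names defined in `OptimizerConfigs`):

    ```
    {'optimizer.optimizer': 'Adam', 'optimizer.learning_rate': 3e-4}
    ```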
    """

    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer


@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    """
    ### Basic English tokenizer

    We use a character-level tokenizer in this experiment.
    You can switch to the basic English tokenizer by setting,

    ```
    'tokenizer': 'basic_english',
    ```

    in the configurations dictionary when starting the experiment.
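
    For example, this tokenizer lowercases the text and splits on whitespace and
    punctuation, so `"AG News!"` becomes roughly `['ag', 'news', '!']`.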
    """
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


def character_tokenizer(x: str):
    """
    ### Character level tokenizer
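
    For example, `character_tokenizer("AG News")` gives `['A', 'G', ' ', 'N', 'e', 'w', 's']`.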
    """
    return list(x)


@option(NLPClassificationConfigs.tokenizer)
def character():
    """
    Character level tokenizer configuration
    """
    return character_tokenizer


@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    """
    Get number of tokens
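
    This is the vocabulary size plus two: one extra id for the padding token
    (`len(vocab)`) and one for the `[CLS]` classifier token (`len(vocab) + 1`)
    used by `CollateFunc` below.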
    """
    return len(c.vocab) + 2


class CollateFunc:
    """
    ## Collate function to batch samples
    """

    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        """
        * `tokenizer` is the tokenizer function
        * `vocab` is the vocabulary
        * `seq_len` is the length of the sequence
        * `padding_token` is the token used for padding when the `seq_len` is larger than the text length
        * `classifier_token` is the `[CLS]` token which we set at the end of the input
        """
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """
        * `batch` is the batch of data collected by the `DataLoader`
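
        For example (an illustrative sketch; AG News labels are `1..4` and get shifted to `0..3`):

        ```
        collate = CollateFunc(tokenizer, vocab, seq_len=512, padding_token=len(vocab), classifier_token=len(vocab) + 1)
        data, labels = collate([(3, "Stocks rally"), (1, "Election results announced")])
        # data: [seq_len, batch_size], labels: [batch_size]
        ```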
        """

        # Input data tensor, initialized with `padding_token`
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
        # Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)

        # Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
            # Set the label
            labels[i] = int(_label) - 1
            # Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
            # Truncate up to `seq_len`
            _text = _text[:self.seq_len]
            # Add the token ids to the `i`-th column of `data`
            data[:len(_text), i] = data.new_tensor(_text)

        # Set the final token in the sequence to `[CLS]`
        data[-1, :] = self.classifier_token

        #
        return data, labels


@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
    """
    ### AG News dataset

    This loads the AG News dataset and sets the values for
    `n_classes`, `vocab`, `train_loader`, and `valid_loader`.
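
    AG News is a four-class news topic classification dataset
    (World, Sports, Business, Sci/Tech), so `n_classes` is `4`.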
    """

    # Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

    # Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

        # Create [map-style datasets](../utils.html#map_style_dataset)
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

    # Get tokenizer
    tokenizer = c.tokenizer

    # Create a counter
    counter = Counter()
    # Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
    # Collect tokens from validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
    # Create vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)

    # Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
    # Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

    # Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
    return 4, vocab, train_loader, valid_loader
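
# A typical experiment would extend `NLPClassificationConfigs` with a `model` option
# and drive it through `labml`'s experiment API. A minimal sketch (illustrative only;
# it assumes `from labml import experiment` and a hypothetical configs subclass
# `MyClassificationConfigs` that defines the model):
#
#     conf = MyClassificationConfigs()
#     experiment.create(name='ag_news_classification')
#     experiment.configs(conf, {'tokenizer': 'basic_english', 'epochs': 16})
#     with experiment.start():
#         conf.run()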