"""
Title: Pretraining a Transformer from scratch with KerasHub
Author: [Matthew Watson](https://github.com/mattdangerw/)
Date created: 2022/04/18
Last modified: 2023/07/15
Description: Use KerasHub to train a Transformer model from scratch.
Accelerator: GPU
Converted to Keras 3 by: [Anshuman Mishra](https://github.com/shivance)
"""
"""
KerasHub aims to make it easy to build state-of-the-art text processing models. In this
guide, we will show how library components simplify pretraining and fine-tuning a
Transformer model from scratch.

This guide is broken into three parts:

1. *Setup*, task definition, and establishing a baseline.
2. *Pretraining* a Transformer model.
3. *Fine-tuning* the Transformer model on our classification task.
"""

"""
## Setup

The following guide uses Keras 3 to work in any of `tensorflow`, `jax` or
`torch`. We select the `jax` backend below, which will give us a particularly
fast train step, but feel free to mix it up.
"""

"""shell
pip install -q --upgrade keras-hub
pip install -q --upgrade keras  # Upgrade to Keras 3.
"""
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"


import keras_hub
import tensorflow as tf
import keras

"""
Next up, we can download two datasets.

- [SST-2](https://paperswithcode.com/sota/sentiment-analysis-on-sst-2-binary): a text
classification dataset and our "end goal". This dataset is often used to benchmark
language models.
- [WikiText-103](https://paperswithcode.com/dataset/wikitext-103): a medium-sized
collection of featured articles from English Wikipedia, which we will use for
pretraining.

Finally, we will download a WordPiece vocabulary, to do sub-word tokenization later on in
this guide.
"""

# Download pretraining data.
keras.utils.get_file(
    origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip",
    extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-103-raw/")

# Download finetuning data.
keras.utils.get_file(
    origin="https://dl.fbaipublicfiles.com/glue/data/SST-2.zip",
    extract=True,
)
sst_dir = os.path.expanduser("~/.keras/datasets/SST-2/")

# Download vocabulary data.
vocab_file = keras.utils.get_file(
    origin="https://storage.googleapis.com/tensorflow/keras-nlp/examples/bert/bert_vocab_uncased.txt",
)
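"""
As a quick, optional peek at what we just downloaded (not part of the original pipeline),
the vocabulary is a plain text file with one WordPiece token per line; reserved tokens
such as `[PAD]` appear first.
"""

# Print the first few entries of the WordPiece vocabulary file.
with open(vocab_file) as f:
    vocab_preview = [f.readline().strip() for _ in range(8)]
print(vocab_preview)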
"""
Next, we define some hyperparameters we will use during training. Note that
`PREDICTIONS_PER_SEQ` is chosen to match the masking rate: with a 128 token sequence and
a 25% mask rate, we expect roughly 0.25 * 128 = 32 masked tokens per example.
"""

# Preprocessing params.
PRETRAINING_BATCH_SIZE = 128
FINETUNING_BATCH_SIZE = 32
SEQ_LENGTH = 128
MASK_RATE = 0.25
PREDICTIONS_PER_SEQ = 32

# Model params.
NUM_LAYERS = 3
MODEL_DIM = 256
INTERMEDIATE_DIM = 512
NUM_HEADS = 4
DROPOUT = 0.1
NORM_EPSILON = 1e-5

# Training params.
PRETRAINING_LEARNING_RATE = 5e-4
PRETRAINING_EPOCHS = 8
FINETUNING_LEARNING_RATE = 5e-5
FINETUNING_EPOCHS = 3
"""
### Load data

We load our data with [tf.data](https://www.tensorflow.org/guide/data), which will allow
us to define input pipelines for tokenizing and preprocessing text.
"""

# Load SST-2.
sst_train_ds = tf.data.experimental.CsvDataset(
    sst_dir + "train.tsv", [tf.string, tf.int32], header=True, field_delim="\t"
).batch(FINETUNING_BATCH_SIZE)
sst_val_ds = tf.data.experimental.CsvDataset(
    sst_dir + "dev.tsv", [tf.string, tf.int32], header=True, field_delim="\t"
).batch(FINETUNING_BATCH_SIZE)

# Load wikitext-103 and filter out short lines.
wiki_train_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.train.raw")
    .filter(lambda x: tf.strings.length(x) > 100)
    .batch(PRETRAINING_BATCH_SIZE)
)
wiki_val_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.valid.raw")
    .filter(lambda x: tf.strings.length(x) > 100)
    .batch(PRETRAINING_BATCH_SIZE)
)

# Take a peek at the sst-2 dataset.
print(sst_train_ds.unbatch().batch(4).take(1).get_single_element())
"""
You can see that our `SST-2` dataset contains relatively short snippets of movie review
text. Our goal is to predict the sentiment of the snippet. A label of 1 indicates
positive sentiment, and a label of 0 negative sentiment.
"""

"""
### Establish a baseline

As a first step, we will establish a baseline of good performance. We don't actually need
KerasHub for this; we can just use core Keras layers.

We will train a simple bag-of-words model, where we learn a positive or negative weight
for each word in our vocabulary. A sample's score is simply the sum of the weights of all
words that are present in the sample.
"""

# This layer will turn our input sentence into a list of 1s and 0s the same size
# as our vocabulary, indicating whether a word is present or absent.
multi_hot_layer = keras.layers.TextVectorization(
    max_tokens=4000, output_mode="multi_hot"
)
multi_hot_layer.adapt(sst_train_ds.map(lambda x, y: x))
multi_hot_ds = sst_train_ds.map(lambda x, y: (multi_hot_layer(x), y))
multi_hot_val_ds = sst_val_ds.map(lambda x, y: (multi_hot_layer(x), y))

# We then learn a logistic regression over that layer, and that's our entire
# baseline model!

inputs = keras.Input(shape=(4000,), dtype="int32")
outputs = keras.layers.Dense(1, activation="sigmoid")(inputs)
baseline_model = keras.Model(inputs, outputs)
baseline_model.compile(loss="binary_crossentropy", metrics=["accuracy"])
baseline_model.fit(multi_hot_ds, validation_data=multi_hot_val_ds, epochs=5)
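"""
Because the baseline is a single dense layer over a multi-hot vector, each vocabulary
slot gets exactly one learned weight. As a quick, optional sanity check (not part of the
original guide), we can peek at a few of those weights; this sketch assumes the ordering
of `TextVectorization.get_vocabulary()` lines up with the multi-hot indices.
"""

# Inspect a handful of learned word weights from the baseline's dense layer.
vocab = multi_hot_layer.get_vocabulary()
word_weights = baseline_model.layers[-1].get_weights()[0][:, 0]
for word, weight in list(zip(vocab, word_weights))[:10]:
    print(f"{word!r}: {weight:+.3f}")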
"""
A bag-of-words approach can be fast and surprisingly powerful, especially when input
examples contain a large number of words. With shorter sequences, it can hit a
performance ceiling.

To do better, we would like to build a model that can evaluate words *in context*. Instead
of evaluating each word in a void, we need to use the information contained in the
*entire ordered sequence* of our input.

This runs us into a problem. `SST-2` is a very small dataset, and there's simply not enough
example text to build a larger, more heavily parameterized model that can learn on a
sequence. We would quickly start to overfit and memorize our training set, without any
increase in our ability to generalize to unseen examples.

Enter **pretraining**, which will allow us to learn on a larger corpus and transfer our
knowledge to the `SST-2` task. And enter **KerasHub**, which will allow us to pretrain a
particularly powerful model, the Transformer, with ease.
"""
"""
## Pretraining

To beat our baseline, we will leverage the `WikiText103` dataset, an unlabeled
collection of Wikipedia articles that is much bigger than `SST-2`.

We are going to train a *transformer*, a highly expressive model which will learn
to embed each word in our input as a low dimensional vector. Our Wikipedia dataset has no
labels, so we will use an unsupervised training objective called the *Masked Language
Modeling* (MaskedLM) objective.

Essentially, we will be playing a big game of "guess the missing word". For each input
sample we will obscure 25% of our input data, and train our model to predict the parts we
covered up.
"""

"""
### Preprocess data for the MaskedLM task

Our text preprocessing for the MaskedLM task will occur in two stages.

1. Tokenize input text into integer sequences of token ids.
2. Mask certain positions in our input to predict on.

To tokenize, we can use a `keras_hub.tokenizers.Tokenizer` -- the KerasHub building block
for transforming text into sequences of integer token ids.

In particular, we will use `keras_hub.tokenizers.WordPieceTokenizer`, which does
*sub-word* tokenization. Sub-word tokenization is popular when training models on large
text corpora. Essentially, it allows our model to learn from uncommon words, while not
requiring a massive vocabulary of every word in our training set.

The second thing we need to do is mask our input for the MaskedLM task. To do this, we can
use `keras_hub.layers.MaskedLMMaskGenerator`, which will randomly select a set of tokens
in each input and mask them out.

The tokenizer and the masking layer can both be used inside a call to
[tf.data.Dataset.map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map).
We can use `tf.data` to efficiently pre-compute each batch on the CPU, while our GPU or TPU
works on training with the batch that came before. Because our masking layer will
choose new words to mask each time, each epoch over our dataset will give us a totally
new set of labels to train on.
"""
# Setting sequence_length will trim or pad the token outputs to shape
# (batch_size, SEQ_LENGTH).
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab_file,
    sequence_length=SEQ_LENGTH,
    lowercase=True,
    strip_accents=True,
)
# Setting mask_selection_length will trim or pad the mask outputs to shape
# (batch_size, PREDICTIONS_PER_SEQ).
masker = keras_hub.layers.MaskedLMMaskGenerator(
    vocabulary_size=tokenizer.vocabulary_size(),
    mask_selection_rate=MASK_RATE,
    mask_selection_length=PREDICTIONS_PER_SEQ,
    mask_token_id=tokenizer.token_to_id("[MASK]"),
)
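"""
As an optional aside (not part of the original pipeline), we can tokenize a single
sentence to see sub-word splitting in action. This sketch assumes the tokenizer's
`id_to_token()` helper, the counterpart of the `token_to_id()` call used above; padding
ids of 0 are skipped for readability.
"""

# Tokenize one sentence and map the ids back to their WordPiece strings.
sample_ids = tokenizer(["Pretraining transformers with KerasHub"])[0]
print([tokenizer.id_to_token(int(i)) for i in sample_ids if int(i) != 0])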
def preprocess(inputs):
    inputs = tokenizer(inputs)
    outputs = masker(inputs)
    # Split the masking layer outputs into a (features, labels, weights)
    # tuple that we can use with keras.Model.fit().
    features = {
        "token_ids": outputs["token_ids"],
        "mask_positions": outputs["mask_positions"],
    }
    labels = outputs["mask_ids"]
    weights = outputs["mask_weights"]
    return features, labels, weights


# We use prefetch() to pre-compute preprocessed batches on the fly on the CPU.
pretrain_ds = wiki_train_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)
pretrain_val_ds = wiki_val_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# Preview a single input example.
# The masks will change each time you run the cell.
print(pretrain_val_ds.take(1).get_single_element())
"""
The above block converts our dataset into a `(features, labels, weights)` tuple, which can
be passed directly to `keras.Model.fit()`.

We have two features:

1. `"token_ids"`, where some tokens have been replaced with our mask token id.
2. `"mask_positions"`, which keeps track of which tokens we masked out.

Our labels are simply the ids we masked out.

Because not all sequences will have the same number of masks, we also keep a
`sample_weight` tensor, which removes padded labels from our loss function by giving them
zero weight.
"""
"""
### Create the Transformer encoder

KerasHub provides all the building blocks to quickly build a Transformer encoder.

We use `keras_hub.layers.TokenAndPositionEmbedding` to first embed our input token ids.
This layer simultaneously learns two embeddings -- one for words in a sentence and another
for integer positions in a sentence. The output embedding is simply the sum of the two.

Then we can add a series of `keras_hub.layers.TransformerEncoder` layers. These are the
bread and butter of the Transformer model, using an attention mechanism to attend to
different parts of the input sentence, followed by a multi-layer perceptron block.

The output of this model will be an encoded vector per input token id. Unlike the
bag-of-words model we used as a baseline, this model will embed each token accounting for
the context in which it appeared.
"""

inputs = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")

# Embed our tokens with a positional embedding.
embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=tokenizer.vocabulary_size(),
    sequence_length=SEQ_LENGTH,
    embedding_dim=MODEL_DIM,
)
outputs = embedding_layer(inputs)

# Apply layer normalization and dropout to the embedding.
outputs = keras.layers.LayerNormalization(epsilon=NORM_EPSILON)(outputs)
outputs = keras.layers.Dropout(rate=DROPOUT)(outputs)

# Add a number of encoder blocks.
for i in range(NUM_LAYERS):
    outputs = keras_hub.layers.TransformerEncoder(
        intermediate_dim=INTERMEDIATE_DIM,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        layer_norm_epsilon=NORM_EPSILON,
    )(outputs)

encoder_model = keras.Model(inputs, outputs)
encoder_model.summary()
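"""
As a quick, optional shape check (not in the original guide), the encoder maps a batch of
token ids to one `MODEL_DIM`-sized vector per position, so the printed shape below should
be `(1, SEQ_LENGTH, MODEL_DIM)`, i.e. `(1, 128, 256)`.
"""

# Run a dummy batch of token ids through the encoder and inspect the output shape.
print(encoder_model(keras.ops.ones((1, SEQ_LENGTH), dtype="int32")).shape)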
"""
### Pretrain the Transformer

You can think of the `encoder_model` as its own modular unit; it is the piece of our
model that we are really interested in for our downstream task. However, we still need to
set up the encoder to train on the MaskedLM task; to do that we attach a
`keras_hub.layers.MaskedLMHead`.

This layer will take as one input the token encodings, and as another the positions we
masked out in the original input. It will gather the token encodings we masked, and
transform them back into predictions over our entire vocabulary.

With that, we are ready to compile and run pretraining. If you are running this in a
Colab, note that this will take about an hour. Training a Transformer is famously compute
intensive, so even this relatively small Transformer will take some time.
"""

# Create the pretraining model by attaching a masked language model head.
inputs = {
    "token_ids": keras.Input(shape=(SEQ_LENGTH,), dtype="int32", name="token_ids"),
    "mask_positions": keras.Input(
        shape=(PREDICTIONS_PER_SEQ,), dtype="int32", name="mask_positions"
    ),
}

# Encode the tokens.
encoded_tokens = encoder_model(inputs["token_ids"])

# Predict an output word for each masked input token.
# We use the input token embedding to project from our encoded vectors to
# vocabulary logits, which has been shown to improve training efficiency.
outputs = keras_hub.layers.MaskedLMHead(
    token_embedding=embedding_layer.token_embedding,
    activation="softmax",
)(encoded_tokens, mask_positions=inputs["mask_positions"])

# Define and compile our pretraining model.
pretraining_model = keras.Model(inputs, outputs)
pretraining_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(PRETRAINING_LEARNING_RATE),
    weighted_metrics=["sparse_categorical_accuracy"],
    jit_compile=True,
)

# Pretrain the model on our wiki text dataset.
pretraining_model.fit(
    pretrain_ds,
    validation_data=pretrain_val_ds,
    epochs=PRETRAINING_EPOCHS,
)

# Save this base model for further finetuning.
encoder_model.save("encoder_model.keras")
"""
## Fine-tuning

After pretraining, we can now fine-tune our model on the `SST-2` dataset. We can
leverage the ability of the encoder we built to predict on words in context to boost
our performance on the downstream task.
"""

"""
### Preprocess data for classification

Preprocessing for fine-tuning is much simpler than for our pretraining MaskedLM task. We
just tokenize our input sentences and we are ready for training!
"""


def preprocess(sentences, labels):
    return tokenizer(sentences), labels


# We use prefetch() to pre-compute preprocessed batches on the fly on our CPU.
finetune_ds = sst_train_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)
finetune_val_ds = sst_val_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# Preview a single input example.
print(finetune_val_ds.take(1).get_single_element())
"""
### Fine-tune the Transformer

To go from our encoded token output to a classification prediction, we need to attach
another "head" to our Transformer model. We can afford to be simple here. We pool
the encoded tokens together, and use a single dense layer to make a prediction.
"""

# Reload the encoder model from disk so we can restart fine-tuning from scratch.
encoder_model = keras.models.load_model("encoder_model.keras", compile=False)

# Take as input the tokenized input.
inputs = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")

# Encode and pool the tokens.
encoded_tokens = encoder_model(inputs)
pooled_tokens = keras.layers.GlobalAveragePooling1D()(encoded_tokens[0])

# Predict an output label.
outputs = keras.layers.Dense(1, activation="sigmoid")(pooled_tokens)

# Define and compile our fine-tuning model.
finetuning_model = keras.Model(inputs, outputs)
finetuning_model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.AdamW(FINETUNING_LEARNING_RATE),
    metrics=["accuracy"],
)

# Finetune the model for the SST-2 task.
finetuning_model.fit(
    finetune_ds,
    validation_data=finetune_val_ds,
    epochs=FINETUNING_EPOCHS,
)
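"""
As an optional sanity check (not part of the original guide), we can run a couple of raw
sentences through the same tokenizer and look at the predicted probability of positive
sentiment for each; the example sentences below are made up for illustration.
"""

# Tokenize two illustrative reviews and predict their positive-sentiment probability.
sample_sentences = ["a thoughtful and moving film", "a dull, lifeless mess"]
probs = finetuning_model.predict(tokenizer(sample_sentences))
print(probs)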
"""
Pretraining was enough to boost our performance to 84%, and this is hardly the ceiling
for Transformer models. You may have noticed during pretraining that our validation
performance was still steadily increasing. Our model is still significantly undertrained.
Training for more epochs, training a larger Transformer, and training on more unlabeled
text would all continue to boost performance significantly.

One of the key goals of KerasHub is to provide a modular approach to NLP model building.
We have shown one approach to building a Transformer here, but KerasHub supports an
ever-growing array of components for preprocessing text and building models. We hope it
makes it easier to experiment on solutions to your natural language problems.
"""