Pretraining a Transformer from scratch with KerasHub
Author: Matthew Watson
Date created: 2022/04/18
Last modified: 2023/07/15
Description: Use KerasHub to train a Transformer model from scratch.
KerasHub aims to make it easy to build state-of-the-art text processing models. In this guide, we will show how library components simplify pretraining and fine-tuning a Transformer model from scratch.
This guide is broken into three parts:
1. Setup, task definition, and establishing a baseline.
2. Pretraining a Transformer model.
3. Fine-tuning the Transformer model on our classification task.
Setup
This guide uses Keras 3 to work in any of tensorflow, jax, or torch. Below we select the jax backend, which will give us a particularly fast train step, but feel free to mix it up.
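A minimal setup sketch follows: the backend is chosen with an environment variable before Keras is imported, and tensorflow is imported only for the tf.data input pipelines used later.

```python
import os

# Pick the backend before importing Keras. "jax" here; "tensorflow" or "torch" work too.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_hub
import tensorflow as tf  # only used for tf.data input pipelines
```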
Next, we download two datasets:

- SST-2: a text classification dataset and our "end goal". This dataset is often used to benchmark language models.
- WikiText-103: a medium-sized collection of featured articles from English Wikipedia, which we will use for pretraining.

Finally, we will download a WordPiece vocabulary, which we will use for sub-word tokenization later in this guide (see the sketch below).
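The sketch below shows how such a vocabulary can be loaded and turned into a WordPiece tokenizer with keras_hub. The filename "bert_vocab_uncased.txt" is an assumption about what the downloaded file is called; adjust it to match your download.

```python
# Load the WordPiece vocabulary into a list of strings.
# "bert_vocab_uncased.txt" is an assumed filename for the downloaded vocabulary.
vocab = [line.rstrip() for line in open("bert_vocab_uncased.txt", encoding="utf-8")]

# Build a WordPiece tokenizer that splits text into sub-word tokens.
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    lowercase=True,
)
```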
Next, we define some hyperparameters we will use during training.
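The values below are a sketch: the model dimensions (sequence length, embedding size, layer count) are the ones reflected in the encoder summary shown later in this guide, while the batch sizes, mask rate, and learning rates are typical choices rather than prescribed ones.

```python
# Preprocessing params.
PRETRAINING_BATCH_SIZE = 128
FINETUNING_BATCH_SIZE = 32
SEQ_LENGTH = 128
MASK_RATE = 0.25
PREDICTIONS_PER_SEQ = 32

# Model params.
NUM_LAYERS = 3
MODEL_DIM = 256
INTERMEDIATE_DIM = 512
NUM_HEADS = 4
DROPOUT = 0.1
NORM_EPSILON = 1e-5

# Training params.
PRETRAINING_LEARNING_RATE = 5e-4
PRETRAINING_EPOCHS = 8
FINETUNING_LEARNING_RATE = 5e-5
FINETUNING_EPOCHS = 3
```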
Load data
We load our data with tf.data, which will allow us to define input pipelines for tokenizing and preprocessing text.
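A sketch of that loading follows. The file paths assume the datasets were extracted to SST-2/ and wikitext-103-raw/ in the working directory, and the length filter threshold is an illustrative choice for dropping very short WikiText lines.

```python
# SST-2 is a TSV file with a "sentence<tab>label" row per example.
sst_train_ds = tf.data.experimental.CsvDataset(
    "SST-2/train.tsv", [tf.string, tf.int32], header=True, field_delim="\t"
).batch(FINETUNING_BATCH_SIZE)
sst_val_ds = tf.data.experimental.CsvDataset(
    "SST-2/dev.tsv", [tf.string, tf.int32], header=True, field_delim="\t"
).batch(FINETUNING_BATCH_SIZE)

# WikiText-103 is raw text, one fragment per line; drop very short lines.
wiki_train_ds = (
    tf.data.TextLineDataset("wikitext-103-raw/wiki.train.raw")
    .filter(lambda x: tf.strings.length(x) > 100)
    .batch(PRETRAINING_BATCH_SIZE)
)
wiki_val_ds = (
    tf.data.TextLineDataset("wikitext-103-raw/wiki.valid.raw")
    .filter(lambda x: tf.strings.length(x) > 100)
    .batch(PRETRAINING_BATCH_SIZE)
)
```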
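The summary below corresponds to the Transformer encoder we will pretrain in the next section. A minimal sketch of such an encoder, consistent with the parameter counts in the summary and assuming the hyperparameters and `vocab` list defined above, could look like this:

```python
# Token ids in, contextual token encodings out.
inputs = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")

# Embed tokens and positions, then normalize and regularize.
x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=len(vocab),
    sequence_length=SEQ_LENGTH,
    embedding_dim=MODEL_DIM,
)(inputs)
x = keras.layers.LayerNormalization(epsilon=NORM_EPSILON)(x)
x = keras.layers.Dropout(rate=DROPOUT)(x)

# A stack of Transformer encoder blocks.
for _ in range(NUM_LAYERS):
    x = keras_hub.layers.TransformerEncoder(
        intermediate_dim=INTERMEDIATE_DIM,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        layer_norm_epsilon=NORM_EPSILON,
    )(x)

encoder_model = keras.Model(inputs, x)
encoder_model.summary()
```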
Model: "functional_3"
| Layer (type) | Output Shape | Param # |
|---|---|---|
| input_layer_1 (InputLayer) | (None, 128) | 0 |
| token_and_position_embedding (TokenAndPositionEmbedding) | (None, 128, 256) | 7,846,400 |
| layer_normalization (LayerNormalization) | (None, 128, 256) | 512 |
| dropout (Dropout) | (None, 128, 256) | 0 |
| transformer_encoder (TransformerEncoder) | (None, 128, 256) | 527,104 |
| transformer_encoder_1 (TransformerEncoder) | (None, 128, 256) | 527,104 |
| transformer_encoder_2 (TransformerEncoder) | (None, 128, 256) | 527,104 |
Total params: 9,428,224 (287.73 MB)
Trainable params: 9,428,224 (287.73 MB)
Non-trainable params: 0 (0.00 B)
Pretrain the Transformer
You can think of the encoder_model as its own modular unit; it is the piece of our model that we are really interested in for our downstream task. However, we still need to set up the encoder to train on the MaskedLM task; to do that we attach a keras_hub.layers.MaskedLMHead.
This layer will take as one input the token encodings, and as another the positions we masked out in the original input. It will gather the token encodings we masked, and transform them back into predictions over our entire vocabulary.
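A sketch of attaching the head, assuming the encoder_model above and the SEQ_LENGTH, PREDICTIONS_PER_SEQ, and `vocab` values defined earlier:

```python
# Two inputs: the (masked) token ids and the positions that were masked out.
inputs = {
    "token_ids": keras.Input(shape=(SEQ_LENGTH,), dtype="int32", name="token_ids"),
    "mask_positions": keras.Input(
        shape=(PREDICTIONS_PER_SEQ,), dtype="int32", name="mask_positions"
    ),
}

# Encode the tokens, then predict a vocabulary distribution for each masked position.
encoded_tokens = encoder_model(inputs["token_ids"])
outputs = keras_hub.layers.MaskedLMHead(
    vocabulary_size=len(vocab),
    activation="softmax",
)(encoded_tokens, mask_positions=inputs["mask_positions"])

pretraining_model = keras.Model(inputs, outputs)
```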
With that, we are ready to compile and run pretraining. If you are running this in a Colab, note that it will take about an hour. Training a Transformer is famously compute-intensive, so even this relatively small Transformer will take some time.
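A compile-and-fit sketch follows. The dataset names pretrain_ds and pretrain_val_ds are assumptions: they stand for the WikiText pipelines after tokenization and random masking have been applied.

```python
# Compile the pretraining model on the MaskedLM objective.
pretraining_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(PRETRAINING_LEARNING_RATE),
    weighted_metrics=["sparse_categorical_accuracy"],
    jit_compile=True,
)

# pretrain_ds / pretrain_val_ds: WikiText pipelines with masking applied (assumed names).
pretraining_model.fit(
    pretrain_ds,
    validation_data=pretrain_val_ds,
    epochs=PRETRAINING_EPOCHS,
)

# Save the encoder so it can be reused for fine-tuning on SST-2.
encoder_model.save("encoder_model.keras")
```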