"""1Title: Pretraining a Transformer from scratch with KerasHub2Author: [Matthew Watson](https://github.com/mattdangerw/)3Date created: 2022/04/184Last modified: 2023/07/155Description: Use KerasHub to train a Transformer model from scratch.6Accelerator: GPU7Converted to Keras 3 by: [Anshuman Mishra](https://github.com/shivance)8"""910"""11KerasHub aims to make it easy to build state-of-the-art text processing models. In this12guide, we will show how library components simplify pretraining and fine-tuning a13Transformer model from scratch.1415This guide is broken into three parts:16171. *Setup*, task definition, and establishing a baseline.182. *Pretraining* a Transformer model.193. *Fine-tuning* the Transformer model on our classification task.20"""2122"""23## Setup2425The following guide uses Keras 3 to work in any of `tensorflow`, `jax` or26`torch`. We select the `jax` backend below, which will give us a particularly27fast train step below, but feel free to mix it up.28"""2930"""shell31pip install -q --upgrade keras-hub32pip install -q --upgrade keras # Upgrade to Keras 3.33"""3435import os3637os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"383940import keras_hub41import tensorflow as tf42import keras4344"""45Next up, we can download two datasets.4647- [SST-2](https://paperswithcode.com/sota/sentiment-analysis-on-sst-2-binary) a text48classification dataset and our "end goal". This dataset is often used to benchmark49language models.50- [WikiText-103](https://paperswithcode.com/dataset/wikitext-103): A medium sized51collection of featured articles from English Wikipedia, which we will use for52pretraining.5354Finally, we will download a WordPiece vocabulary, to do sub-word tokenization later on in55this guide.56"""5758# Download pretraining data.59keras.utils.get_file(60origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip",61extract=True,62)63wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-103-raw/")6465# Download finetuning data.66keras.utils.get_file(67origin="https://dl.fbaipublicfiles.com/glue/data/SST-2.zip",68extract=True,69)70sst_dir = os.path.expanduser("~/.keras/datasets/SST-2/")7172# Download vocabulary data.73vocab_file = keras.utils.get_file(74origin="https://storage.googleapis.com/tensorflow/keras-nlp/examples/bert/bert_vocab_uncased.txt",75)7677"""78Next, we define some hyperparameters we will use during training.79"""8081# Preprocessing params.82PRETRAINING_BATCH_SIZE = 12883FINETUNING_BATCH_SIZE = 3284SEQ_LENGTH = 12885MASK_RATE = 0.2586PREDICTIONS_PER_SEQ = 328788# Model params.89NUM_LAYERS = 390MODEL_DIM = 25691INTERMEDIATE_DIM = 51292NUM_HEADS = 493DROPOUT = 0.194NORM_EPSILON = 1e-59596# Training params.97PRETRAINING_LEARNING_RATE = 5e-498PRETRAINING_EPOCHS = 899FINETUNING_LEARNING_RATE = 5e-5100FINETUNING_EPOCHS = 3101102"""103### Load data104105We load our data with [tf.data](https://www.tensorflow.org/guide/data), which will allow106us to define input pipelines for tokenizing and preprocessing text.107"""108109# Load SST-2.110sst_train_ds = tf.data.experimental.CsvDataset(111sst_dir + "train.tsv", [tf.string, tf.int32], header=True, field_delim="\t"112).batch(FINETUNING_BATCH_SIZE)113sst_val_ds = tf.data.experimental.CsvDataset(114sst_dir + "dev.tsv", [tf.string, tf.int32], header=True, field_delim="\t"115).batch(FINETUNING_BATCH_SIZE)116117# Load wikitext-103 and filter out short lines.118wiki_train_ds = (119tf.data.TextLineDataset(wiki_dir + "wiki.train.raw")120.filter(lambda x: tf.strings.length(x) > 
"""
Next, we define some hyperparameters we will use during training.
"""

# Preprocessing params.
PRETRAINING_BATCH_SIZE = 128
FINETUNING_BATCH_SIZE = 32
SEQ_LENGTH = 128
MASK_RATE = 0.25
PREDICTIONS_PER_SEQ = 32

# Model params.
NUM_LAYERS = 3
MODEL_DIM = 256
INTERMEDIATE_DIM = 512
NUM_HEADS = 4
DROPOUT = 0.1
NORM_EPSILON = 1e-5

# Training params.
PRETRAINING_LEARNING_RATE = 5e-4
PRETRAINING_EPOCHS = 8
FINETUNING_LEARNING_RATE = 5e-5
FINETUNING_EPOCHS = 3

"""
### Load data

We load our data with [tf.data](https://www.tensorflow.org/guide/data), which will allow
us to define input pipelines for tokenizing and preprocessing text.
"""

# Load SST-2.
sst_train_ds = tf.data.experimental.CsvDataset(
    sst_dir + "train.tsv", [tf.string, tf.int32], header=True, field_delim="\t"
).batch(FINETUNING_BATCH_SIZE)
sst_val_ds = tf.data.experimental.CsvDataset(
    sst_dir + "dev.tsv", [tf.string, tf.int32], header=True, field_delim="\t"
).batch(FINETUNING_BATCH_SIZE)

# Load wikitext-103 and filter out short lines.
wiki_train_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.train.raw")
    .filter(lambda x: tf.strings.length(x) > 100)
    .batch(PRETRAINING_BATCH_SIZE)
)
wiki_val_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.valid.raw")
    .filter(lambda x: tf.strings.length(x) > 100)
    .batch(PRETRAINING_BATCH_SIZE)
)

# Take a peek at the sst-2 dataset.
print(sst_train_ds.unbatch().batch(4).take(1).get_single_element())

"""
You can see that our `SST-2` dataset contains relatively short snippets of movie review
text. Our goal is to predict the sentiment of the snippet. A label of 1 indicates
positive sentiment, and a label of 0 negative sentiment.
"""

"""
### Establish a baseline

As a first step, we will establish a baseline of good performance. We don't actually need
KerasHub for this; we can just use core Keras layers.

We will train a simple bag-of-words model, where we learn a positive or negative weight
for each word in our vocabulary. A sample's score is simply the sum of the weights of all
words that are present in the sample.
"""

# This layer will turn our input sentence into a list of 1s and 0s the same size
# as our vocabulary, indicating whether a word is present or absent.
multi_hot_layer = keras.layers.TextVectorization(
    max_tokens=4000, output_mode="multi_hot"
)
multi_hot_layer.adapt(sst_train_ds.map(lambda x, y: x))
multi_hot_ds = sst_train_ds.map(lambda x, y: (multi_hot_layer(x), y))
multi_hot_val_ds = sst_val_ds.map(lambda x, y: (multi_hot_layer(x), y))

# We then learn a logistic regression over that layer, and that's our entire
# baseline model!

inputs = keras.Input(shape=(4000,), dtype="int32")
outputs = keras.layers.Dense(1, activation="sigmoid")(inputs)
baseline_model = keras.Model(inputs, outputs)
baseline_model.compile(loss="binary_crossentropy", metrics=["accuracy"])
baseline_model.fit(multi_hot_ds, validation_data=multi_hot_val_ds, epochs=5)
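"""
Because this baseline really does learn just one weight per vocabulary word, we can peek
at what it picked up. The sketch below assumes the adapted vocabulary fills all 4,000
slots, so that index `i` of the multi-hot vector lines up with
`multi_hot_layer.get_vocabulary()[i]`.
"""

# Pair each learned weight from the Dense layer with its vocabulary word.
vocab = multi_hot_layer.get_vocabulary()
word_weights = baseline_model.layers[-1].get_weights()[0][:, 0]
# The words the model scores as most positive and most negative.
print("Most positive:", [vocab[i] for i in word_weights.argsort()[-5:][::-1]])
print("Most negative:", [vocab[i] for i in word_weights.argsort()[:5]])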
"""
A bag-of-words approach can be fast and surprisingly powerful, especially when input
examples contain a large number of words. With shorter sequences, it can hit a
performance ceiling.

To do better, we would like to build a model that can evaluate words *in context*. Instead
of evaluating each word in isolation, we need to use the information contained in the
*entire ordered sequence* of our input.

This runs us into a problem. `SST-2` is a very small dataset, and there's simply not enough
example text to attempt to build a larger, more heavily parameterized model that can learn
on a sequence. We would quickly start to overfit and memorize our training set, without any
increase in our ability to generalize to unseen examples.

Enter **pretraining**, which will allow us to learn on a larger corpus, and transfer our
knowledge to the `SST-2` task. And enter **KerasHub**, which will allow us to pretrain a
particularly powerful model, the Transformer, with ease.
"""

"""
## Pretraining

To beat our baseline, we will leverage the `WikiText103` dataset, an unlabeled
collection of Wikipedia articles that is much bigger than `SST-2`.

We are going to train a *transformer*, a highly expressive model which will learn
to embed each word in our input as a low dimensional vector. Our wikipedia dataset has no
labels, so we will use an unsupervised training objective called the *Masked Language
Modeling* (MaskedLM) objective.

Essentially, we will be playing a big game of "guess the missing word". For each input
sample we will obscure 25% of our input data, and train our model to predict the parts we
covered up.
"""

"""
### Preprocess data for the MaskedLM task

Our text preprocessing for the MaskedLM task will occur in two stages.

1. Tokenize input text into integer sequences of token ids.
2. Mask certain positions in our input to predict on.

To tokenize, we can use a `keras_hub.tokenizers.Tokenizer` -- the KerasHub building block
for transforming text into sequences of integer token ids.

In particular, we will use `keras_hub.tokenizers.WordPieceTokenizer` which does
*sub-word* tokenization. Sub-word tokenization is popular when training models on large
text corpora. Essentially, it allows our model to learn from uncommon words, while not
requiring a massive vocabulary of every word in our training set.

The second thing we need to do is mask our input for the MaskedLM task. To do this, we can
use `keras_hub.layers.MaskedLMMaskGenerator`, which will randomly select a set of tokens in
each input and mask them out.

The tokenizer and the masking layer can both be used inside a call to
[tf.data.Dataset.map](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map).
We can use `tf.data` to efficiently pre-compute each batch on the CPU, while our GPU or TPU
works on training with the batch that came before. Because our masking layer will
choose new words to mask each time, each epoch over our dataset will give us a totally
new set of labels to train on.
"""

# Setting sequence_length will trim or pad the token outputs to shape
# (batch_size, SEQ_LENGTH).
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab_file,
    sequence_length=SEQ_LENGTH,
    lowercase=True,
    strip_accents=True,
)
# Setting mask_selection_length will trim or pad the mask outputs to shape
# (batch_size, PREDICTIONS_PER_SEQ).
masker = keras_hub.layers.MaskedLMMaskGenerator(
    vocabulary_size=tokenizer.vocabulary_size(),
    mask_selection_rate=MASK_RATE,
    mask_selection_length=PREDICTIONS_PER_SEQ,
    mask_token_id=tokenizer.token_to_id("[MASK]"),
)


def preprocess(inputs):
    inputs = tokenizer(inputs)
    outputs = masker(inputs)
    # Split the masking layer outputs into a (features, labels, and weights)
    # tuple that we can use with keras.Model.fit().
    features = {
        "token_ids": outputs["token_ids"],
        "mask_positions": outputs["mask_positions"],
    }
    labels = outputs["mask_ids"]
    weights = outputs["mask_weights"]
    return features, labels, weights


# We use prefetch() to pre-compute preprocessed batches on the fly on the CPU.
pretrain_ds = wiki_train_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)
pretrain_val_ds = wiki_val_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# Preview a single input example.
# The masks will change each time you run the cell.
print(pretrain_val_ds.take(1).get_single_element())

"""
The above block formats our dataset into a `(features, labels, weights)` tuple, which can
be passed directly to `keras.Model.fit()`.

We have two features:

1. `"token_ids"`, where some tokens have been replaced with our mask token id.
2. `"mask_positions"`, which keeps track of which tokens we masked out.

Our labels are simply the ids we masked out.

Because not all sequences will have the same number of masks, we also keep a
`sample_weight` tensor, which removes padded labels from our loss function by giving them
zero weight.
"""
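"""
As a quick, optional check of that last point, we can sum the mask weights of a single
batch along the last axis: each entry is the number of real (non-padded) predictions in
that sample, and every position past that count has weight zero, so it contributes nothing
to the loss.
"""

# Peek at the mask weights for one batch (illustrative only).
features, labels, weights = pretrain_val_ds.take(1).get_single_element()
print(tf.reduce_sum(weights, axis=-1))  # Number of real masked positions per sample.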
`"mask_positions"`, which keeps track of which tokens we masked out.282283Our labels are simply the ids we masked out.284285Because not all sequences will have the same number of masks, we also keep a286`sample_weight` tensor, which removes padded labels from our loss function by giving them287zero weight.288"""289290"""291### Create the Transformer encoder292293KerasHub provides all the building blocks to quickly build a Transformer encoder.294295We use `keras_hub.layers.TokenAndPositionEmbedding` to first embed our input token ids.296This layer simultaneously learns two embeddings -- one for words in a sentence and another297for integer positions in a sentence. The output embedding is simply the sum of the two.298299Then we can add a series of `keras_hub.layers.TransformerEncoder` layers. These are the300bread and butter of the Transformer model, using an attention mechanism to attend to301different parts of the input sentence, followed by a multi-layer perceptron block.302303The output of this model will be a encoded vector per input token id. Unlike the304bag-of-words model we used as a baseline, this model will embed each token accounting for305the context in which it appeared.306"""307308inputs = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")309310# Embed our tokens with a positional embedding.311embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(312vocabulary_size=tokenizer.vocabulary_size(),313sequence_length=SEQ_LENGTH,314embedding_dim=MODEL_DIM,315)316outputs = embedding_layer(inputs)317318# Apply layer normalization and dropout to the embedding.319outputs = keras.layers.LayerNormalization(epsilon=NORM_EPSILON)(outputs)320outputs = keras.layers.Dropout(rate=DROPOUT)(outputs)321322# Add a number of encoder blocks323for i in range(NUM_LAYERS):324outputs = keras_hub.layers.TransformerEncoder(325intermediate_dim=INTERMEDIATE_DIM,326num_heads=NUM_HEADS,327dropout=DROPOUT,328layer_norm_epsilon=NORM_EPSILON,329)(outputs)330331encoder_model = keras.Model(inputs, outputs)332encoder_model.summary()333334"""335### Pretrain the Transformer336337You can think of the `encoder_model` as it's own modular unit, it is the piece of our338model that we are really interested in for our downstream task. However we still need to339set up the encoder to train on the MaskedLM task; to do that we attach a340`keras_hub.layers.MaskedLMHead`.341342This layer will take as one input the token encodings, and as another the positions we343masked out in the original input. It will gather the token encodings we masked, and344transform them back in predictions over our entire vocabulary.345346With that, we are ready to compile and run pretraining. If you are running this in a347Colab, note that this will take about an hour. 
"""
### Pretrain the Transformer

You can think of the `encoder_model` as its own modular unit; it is the piece of our
model that we are really interested in for our downstream task. However we still need to
set up the encoder to train on the MaskedLM task; to do that we attach a
`keras_hub.layers.MaskedLMHead`.

This layer will take as one input the token encodings, and as another the positions we
masked out in the original input. It will gather the token encodings we masked, and
transform them back into predictions over our entire vocabulary.

With that, we are ready to compile and run pretraining. If you are running this in a
Colab, note that this will take about an hour. Training a Transformer is famously compute
intensive, so even this relatively small Transformer will take some time.
"""

# Create the pretraining model by attaching a masked language model head.
inputs = {
    "token_ids": keras.Input(shape=(SEQ_LENGTH,), dtype="int32", name="token_ids"),
    "mask_positions": keras.Input(
        shape=(PREDICTIONS_PER_SEQ,), dtype="int32", name="mask_positions"
    ),
}

# Encode the tokens.
encoded_tokens = encoder_model(inputs["token_ids"])

# Predict an output word for each masked input token.
# We use the input token embedding to project from our encoded vectors to
# vocabulary logits, which has been shown to improve training efficiency.
outputs = keras_hub.layers.MaskedLMHead(
    token_embedding=embedding_layer.token_embedding,
    activation="softmax",
)(encoded_tokens, mask_positions=inputs["mask_positions"])

# Define and compile our pretraining model.
pretraining_model = keras.Model(inputs, outputs)
pretraining_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(PRETRAINING_LEARNING_RATE),
    weighted_metrics=["sparse_categorical_accuracy"],
    jit_compile=True,
)

# Pretrain the model on our wiki text dataset.
pretraining_model.fit(
    pretrain_ds,
    validation_data=pretrain_val_ds,
    epochs=PRETRAINING_EPOCHS,
)

# Save this base model for further finetuning.
encoder_model.save("encoder_model.keras")

"""
## Fine-tuning

After pretraining, we can now fine-tune our model on the `SST-2` dataset. We can
leverage the ability of the encoder we built to predict on words in context to boost
our performance on the downstream task.
"""

"""
### Preprocess data for classification

Preprocessing for fine-tuning is much simpler than for our pretraining MaskedLM task. We
just tokenize our input sentences and we are ready for training!
"""


def preprocess(sentences, labels):
    return tokenizer(sentences), labels


# We use prefetch() to pre-compute preprocessed batches on the fly on our CPU.
finetune_ds = sst_train_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)
finetune_val_ds = sst_val_ds.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# Preview a single input example.
print(finetune_val_ds.take(1).get_single_element())

"""
### Fine-tune the Transformer

To go from our encoded token output to a classification prediction, we need to attach
another "head" to our Transformer model. We can afford to be simple here. We pool
the encoded tokens together, and use a single dense layer to make a prediction.
"""

# Reload the encoder model from disk so we can restart fine-tuning from scratch.
encoder_model = keras.models.load_model("encoder_model.keras", compile=False)

# Take as input the tokenized input.
inputs = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")

# Encode and pool the tokens.
encoded_tokens = encoder_model(inputs)
pooled_tokens = keras.layers.GlobalAveragePooling1D()(encoded_tokens[0])

# Predict an output label.
outputs = keras.layers.Dense(1, activation="sigmoid")(pooled_tokens)

# Define and compile our fine-tuning model.
finetuning_model = keras.Model(inputs, outputs)
finetuning_model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.AdamW(FINETUNING_LEARNING_RATE),
    metrics=["accuracy"],
)

# Finetune the model for the SST-2 task.
finetuning_model.fit(
    finetune_ds,
    validation_data=finetune_val_ds,
    epochs=FINETUNING_EPOCHS,
)
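"""
To see the fine-tuned model in action, we can sketch how inference on raw text would look:
tokenize new sentences exactly as we did for training, then call `predict()`. The two
review snippets below are made-up examples; the output is one sigmoid score per sentence,
with values above 0.5 indicating positive sentiment.
"""

# Tokenize a couple of example sentences with the same tf.data flow used above.
sample_ds = (
    tf.data.Dataset.from_tensor_slices(
        ["a gorgeous, witty, and altogether winning film", "a dull, lifeless mess"]
    )
    .batch(2)
    .map(tokenizer)
)
print(finetuning_model.predict(sample_ds))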
"""
Pretraining was enough to boost our performance to 84%, and this is hardly the ceiling
for Transformer models. You may have noticed during pretraining that our validation
performance was still steadily increasing. Our model is still significantly undertrained.
Training for more epochs, training a larger Transformer, and training on more unlabeled
text would all continue to boost performance significantly.

One of the key goals of KerasHub is to provide a modular approach to NLP model building.
We have shown one approach to building a Transformer here, but KerasHub supports an ever
growing array of components for preprocessing text and building models. We hope it makes
it easier to experiment with solutions to your natural language problems.
"""