Parameter-efficient fine-tuning of GPT-2 with LoRA
Author: Abheesht Sharma, Matthew Watson
Date created: 2023/05/27
Last modified: 2023/05/27
Description: Use KerasHub to fine-tune a GPT-2 LLM with LoRA.
Introduction
Large Language Models (LLMs) have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned on a downstream task of interest (such as sentiment analysis).
However, LLMs are extremely large, and we don't need to train all of their parameters during fine-tuning, especially since fine-tuning datasets are usually relatively small. In other words, LLMs are over-parametrized for fine-tuning. This is where Low-Rank Adaptation (LoRA) comes in: it significantly reduces the number of trainable parameters, which decreases training time and GPU memory usage while maintaining the quality of the outputs.
In this example, we will explain LoRA in technical terms, show how the technical explanation translates to code, hack KerasHub's GPT-2 model and fine-tune it on the next token prediction task using LoRA. We will compare LoRA GPT-2 with a fully fine-tuned GPT-2 in terms of the quality of the generated text, training time and GPU memory usage.
Note: This example runs on the TensorFlow backend purely for the tf.config.experimental.get_memory_info API to easily plot memory usage. Outside of the memory usage callback, this example will run on the jax and torch backends.
Setup
Before we start implementing the pipeline, let's install and import all the libraries we need. We'll be using the KerasHub library.
Next, let's enable mixed precision training, which runs most computations in float16 while keeping the model weights in float32. This will help us reduce the training time.
Let's also define our hyperparameters.
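One way to keep the hyperparameters in one place is a small frozen dataclass. This is a sketch: the sequence length (128) and LoRA rank (4) match values used later in this example, while the batch size, number of batches and epoch count are hypothetical placeholders, not necessarily the values behind the reported results.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Hyperparameters:
    """Hypothetical hyperparameter container for this example.

    max_sequence_length and rank match values stated in the text;
    batch_size, num_batches and epochs are illustrative placeholders.
    """

    max_sequence_length: int = 128  # stated in the text (the default is 1024)
    rank: int = 4  # LoRA rank used in the parameter-count math
    batch_size: int = 32  # hypothetical placeholder
    num_batches: int = 500  # hypothetical placeholder: size of the training subset
    epochs: int = 1  # hypothetical placeholder

    @property
    def num_training_examples(self) -> int:
        # Documents seen per epoch when training on a subset of batches.
        return self.batch_size * self.num_batches


HPARAMS = Hyperparameters()
print(HPARAMS)
print("Documents per epoch:", HPARAMS.num_training_examples)
```

Freezing the dataclass prevents a hyperparameter from being changed accidentally between the GPT-2 and LoRA GPT-2 runs, which keeps the comparison fair.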
Dataset
Let's load a Reddit dataset. We will fine-tune both the GPT-2 model and the LoRA GPT-2 model on a subset of this dataset. The aim is to produce text similar in style to Reddit posts.
The dataset has two fields: document and title.
We'll now batch the dataset and retain only the document field, because we are fine-tuning the model on the next-token prediction task. To keep this example quick, we also take only a subset of the dataset.
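The pipeline steps above (keep one field, batch, take a subset) can be sketched in plain Python. This is only an illustration of the logic, roughly what map, batch and take do on a tf.data.Dataset; the toy records are hypothetical stand-ins for the Reddit data.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def batch_documents(
    records: Iterable[dict], batch_size: int, num_batches: int
) -> Iterator[List[str]]:
    """Keep only the "document" field, group it into batches, and take a subset.

    `records` is assumed to be an iterable of dicts with "document" and "title" keys.
    """
    documents = (record["document"] for record in records)  # drop "title"

    def batches() -> Iterator[List[str]]:
        while True:
            batch = list(islice(documents, batch_size))
            if not batch:
                return
            yield batch

    # Only the first `num_batches` batches are used for training.
    return islice(batches(), num_batches)


# Hypothetical toy records standing in for the Reddit dataset.
toy_records = [{"document": f"post {i}", "title": f"title {i}"} for i in range(10)]
subset = list(batch_documents(toy_records, batch_size=4, num_batches=2))
print(subset)  # [['post 0', 'post 1', 'post 2', 'post 3'], ['post 4', 'post 5', 'post 6', 'post 7']]
```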
Helper functions
Before we begin fine-tuning the models, let's define a few helper functions and classes.
Callback for tracking GPU memory usage
We'll define a custom callback that tracks GPU memory usage. The callback uses TensorFlow's tf.config.experimental.get_memory_info API.
Here, we assume that we are using a single GPU, GPU:0.
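A memory tracker along these lines might look like the sketch below. It is framework-agnostic on purpose: it mirrors the Keras callback hook on_train_batch_begin(batch, logs=None) but does not subclass keras.callbacks.Callback, and the memory reader is injected so the class can be exercised without a GPU. With TensorFlow, one would pass lambda: tf.config.experimental.get_memory_info("GPU:0"), which returns a dict with "current" and "peak" byte counts. The class name and the choice of target batches are assumptions.

```python
from typing import Callable, Dict, List, Optional, Sequence


class GPUMemoryTracker:
    """Records peak GPU memory (in GiB) at selected training batches (hypothetical class)."""

    def __init__(
        self,
        target_batches: Sequence[int],
        get_memory_info: Callable[[], Dict[str, int]],
    ):
        self.target_batches = set(target_batches)
        self.get_memory_info = get_memory_info
        self.memory_usage: List[float] = []

    def _record_peak(self) -> None:
        stats = self.get_memory_info()
        # "peak" is reported in bytes; convert to GiB and round for plotting.
        self.memory_usage.append(round(stats["peak"] / 2**30, 3))

    def on_train_batch_begin(self, batch: int, logs: Optional[dict] = None) -> None:
        if batch in self.target_batches:
            self._record_peak()


# Usage with a fake memory reader that reports a growing peak.
fake_peaks = iter([1 * 2**30, 2 * 2**30, 3 * 2**30])
tracker = GPUMemoryTracker(
    target_batches=[0, 2, 4], get_memory_info=lambda: {"peak": next(fake_peaks)}
)
for batch in range(5):
    tracker.on_train_batch_begin(batch)
print(tracker.memory_usage)  # [1.0, 2.0, 3.0]
```

Because the peak counter is cumulative, sampling it at a few batches is enough to see how much memory each model needs at its worst.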
Function for text generation
Here is a helper function to generate text.
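Such a helper can be as small as the sketch below. It assumes the model exposes a generate(prompt, max_length=...) method, as KerasHub causal LMs such as GPT2CausalLM do; EchoModel is a hypothetical stand-in used only to exercise the function.

```python
import time
from typing import Any, Tuple


def generate_text(model: Any, input_text: str, max_length: int = 200) -> Tuple[str, float]:
    """Generate text with `model.generate` and report the wall-clock time."""
    start = time.perf_counter()
    output = model.generate(input_text, max_length=max_length)
    elapsed = time.perf_counter() - start
    print("Output:")
    print(output)
    print(f"Total time elapsed: {elapsed:.2f}s")
    return output, elapsed


class EchoModel:
    """Hypothetical stand-in model that just echoes the prompt."""

    def generate(self, prompt: str, max_length: int = 200) -> str:
        return (prompt + " ...")[:max_length]


text, seconds = generate_text(EchoModel(), "I like basketball", max_length=50)
```

Returning the elapsed time as well as printing it makes it easy to compare the slow first (compiling) call with later calls.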
Define optimizer and loss
We will use the AdamW optimizer and the cross-entropy loss for training both models.
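For next-token prediction, the cross-entropy loss is the average negative log-probability the model assigns to the token that actually comes next. The sketch below computes it from raw logits, which is what a sparse categorical cross-entropy with from_logits=True (for example keras.losses.SparseCategoricalCrossentropy) computes per position; padding masks are ignored here for simplicity.

```python
import math
from typing import List


def sparse_cross_entropy_from_logits(logits: List[List[float]], targets: List[int]) -> float:
    """Mean cross-entropy over positions for next-token prediction.

    logits[t] holds unnormalized scores over the vocabulary at position t, and
    targets[t] is the id of the token that actually comes next.
    """
    total = 0.0
    for scores, target in zip(logits, targets):
        # log-sum-exp with max subtraction for numerical stability
        max_score = max(scores)
        log_sum_exp = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        total += log_sum_exp - scores[target]  # equals -log p(target)
    return total / len(targets)


# Uniform logits over a 4-token vocabulary give a loss of log(4).
loss = sparse_cross_entropy_from_logits([[0.0, 0.0, 0.0, 0.0]], [2])
print(round(loss, 4))  # 1.3863
```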
Fine-tune GPT-2
Let's load the model and preprocessor first. We use a sequence length of 128 instead of 1024 (which is the default sequence length). This will limit our ability to predict long sequences, but will allow us to run this example quickly on Colab.
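Shortening the sequence length means that longer documents are truncated. The sketch below is a simplified assumption about what a causal LM preprocessor such as GPT2CausalLMPreprocessor does (the real one also tokenizes, may add special tokens, and works on batched tensors): it pads or truncates to sequence_length + 1 tokens and shifts by one, so the label at position t is the token at position t + 1.

```python
from typing import List, Tuple


def causal_lm_example(
    token_ids: List[int], sequence_length: int, pad_id: int = 0
) -> Tuple[List[int], List[int], List[bool]]:
    """Build (inputs, labels, mask) for next-token prediction (simplified sketch)."""
    window = token_ids[: sequence_length + 1]
    mask = [True] * len(window) + [False] * (sequence_length + 1 - len(window))
    window = window + [pad_id] * (sequence_length + 1 - len(window))
    inputs, labels = window[:-1], window[1:]
    # A label only counts if the token it points to is real, not padding.
    return inputs, labels, mask[1:]


x, y, sample_weight = causal_lm_example([11, 12, 13], sequence_length=4)
print(x)  # [11, 12, 13, 0]
print(y)  # [12, 13, 0, 0]
print(sample_weight)  # [True, True, False, False]
```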
Initialize the GPU memory tracker callback object, and compile the model. We use the AdamW optimizer with a linearly decaying learning rate.
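A linearly decaying learning rate is polynomial decay with power 1 (in Keras, keras.optimizers.schedules.PolynomialDecay uses power=1.0 by default). The function below sketches the formula; the initial rate, end rate and number of decay steps in the usage line are hypothetical values.

```python
def linear_decay(step: int, initial_lr: float, end_lr: float, decay_steps: int) -> float:
    """Learning rate that falls linearly from initial_lr to end_lr over decay_steps.

    After decay_steps, the rate stays at end_lr.
    """
    progress = min(step, decay_steps) / decay_steps
    return (initial_lr - end_lr) * (1.0 - progress) + end_lr


# Hypothetical values: start at 5e-5 and reach 0 after 500 optimizer steps.
schedule = [linear_decay(s, 5e-5, 0.0, 500) for s in (0, 250, 500, 600)]
print(schedule)
```

Tying decay_steps to the number of training batches times the number of epochs makes the rate reach its end value exactly when training finishes.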
We are all set to train the model!
As a final step, let's generate some text. We will harness the power of XLA. The first call to generate() will be slow because of XLA compilation, but subsequent calls will be super-fast. 😃
LoRA GPT-2
In this section, we discuss the technical details of LoRA, build a LoRA GPT-2 model, fine-tune it and generate text.
What exactly is LoRA?
LoRA is a parameter-efficient fine-tuning technique for LLMs. It freezes the weights of the LLM, and injects trainable rank-decomposition matrices. Let's understand this more clearly.
Assume we have an n x n pre-trained dense layer (or weight matrix), W0. We initialize two dense layers, A and B, of shapes n x rank and rank x n, respectively, where rank is much smaller than n. In the LoRA paper, rank values between 1 and 4 are shown to work well.
LoRA equation
The original equation is output = W0x + b0, where x is the input, and W0 and b0 are the weight matrix and bias of the original (frozen) dense layer. The LoRA equation is output = W0x + b0 + BAx, where A and B are the rank-decomposition matrices.
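The LoRA equation can be checked numerically with a small pure-Python sketch. It uses the column-vector form of the equation, where A maps n to rank (a rank x n matrix) and B maps rank back to n (an n x rank matrix); the "n x rank" and "rank x n" shapes above describe the same layers as Keras kernels, which multiply row vectors. Initializing B to zeros follows the LoRA paper and is an assumption here, since this example's text does not state the initialization.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols column vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W0: Matrix, b0: Vector, A: Matrix, B: Matrix, x: Vector) -> Vector:
    """output = W0 x + b0 + B (A x), with W0 and b0 frozen.

    Computing B (A x) never materializes the n x n product BA.
    """
    frozen = [w + b for w, b in zip(matvec(W0, x), b0)]
    update = matvec(B, matvec(A, x))  # project down to `rank`, then back up to n
    return [f + u for f, u in zip(frozen, update)]


n, rank = 6, 2
rng = random.Random(0)
W0 = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(n)]
b0 = [rng.gauss(0, 1) for _ in range(n)]
A = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(rank)]
B = [[0.0] * rank for _ in range(n)]  # zero init: the adapter starts as a no-op
x = [rng.gauss(0, 1) for _ in range(n)]

base = [w + b for w, b in zip(matvec(W0, x), b0)]
print(lora_forward(W0, b0, A, B, x) == base)  # True
```

With B at zero, the LoRA model initially reproduces the pre-trained layer exactly, and training only moves it away from that starting point through the low-rank term.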
LoRA is based on the idea that updates to the weights of a pre-trained language model have a low "intrinsic rank", since pre-trained language models are over-parametrized. The predictive performance of full fine-tuning can be replicated even when W0's updates are constrained to low-rank decomposition matrices.
Number of trainable parameters
Let's do some quick math. Suppose n is 768 and rank is 4. W0 has 768 x 768 = 589,824 parameters, whereas the LoRA layers A and B together have 768 x 4 + 4 x 768 = 6,144 parameters. So, for the dense layer, we go from 589,824 trainable parameters to 6,144 trainable parameters!
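The same arithmetic generalizes: the LoRA pair has 2 * n * rank weights versus n * n for W0, a ratio of 2 * rank / n. A small sketch, ignoring biases as the text does:

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Number of weights in a bias-free dense layer mapping n_in -> n_out."""
    return n_in * n_out


def lora_params(n: int, rank: int) -> int:
    """Weights in the LoRA pair A (n -> rank) and B (rank -> n)."""
    return dense_params(n, rank) + dense_params(rank, n)


n, rank = 768, 4
full = dense_params(n, n)
lora = lora_params(n, rank)
print(full, lora)  # 589824 6144
print(f"LoRA trains {lora / full:.2%} of the layer's weights")  # LoRA trains 1.04% of the layer's weights
```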
Why does LoRA reduce memory footprint?
Even though the total number of parameters increases (since we are adding LoRA layers), the memory footprint decreases, because the number of trainable parameters decreases. Let's dive deeper into this.
The memory usage of a model can be split into four parts:
Model memory: This is the memory required to store the model weights. This will be slightly higher for LoRA than GPT-2.
Forward pass memory: This mostly depends on batch size, sequence length, etc. We keep this constant for both models for a fair comparison.
Backward pass memory: This is the memory required to store the gradients. Note that the gradients are computed only for the trainable parameters.
Optimizer memory: This is the memory required to store the optimizer state. For example, the Adam optimizer stores the "1st moment vectors" and "2nd moment vectors" for the trainable parameters.
Since LoRA drastically reduces the number of trainable parameters, the memory needed for the optimizer state and for the gradients is much smaller for LoRA GPT-2 than for GPT-2. This is where most of the memory savings happen.
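A back-of-the-envelope estimate shows the scale of the gradient and optimizer savings. The sketch assumes one gradient plus Adam's two moment vectors per trainable parameter, each stored as float32 (4 bytes); it deliberately leaves out model weights and forward-pass activations, so it is not a prediction of the measured peak memory. The parameter counts are the ones used elsewhere in this example.

```python
def training_state_bytes(trainable_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for gradients plus Adam's 1st and 2nd moment vectors (float32 assumed)."""
    values_per_param = 1 + 2  # gradient + 1st moment + 2nd moment
    return trainable_params * values_per_param * bytes_per_value


full_trainable = 124_439_808  # every GPT-2 weight is trainable in full fine-tuning
lora_trainable = 147_456  # only the LoRA A/B weights are trainable

for name, count in [("GPT-2", full_trainable), ("LoRA GPT-2", lora_trainable)]:
    mib = training_state_bytes(count) / 2**20
    print(f"{name}: {mib:,.1f} MiB of gradient + optimizer state")
```

The gap is large in absolute terms, but the measured saving is smaller than this ratio suggests because weights and activations, which LoRA does not shrink, still take up most of the remaining memory.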
Why is LoRA so popular?
Reduces GPU memory usage;
Faster training; and
No additional inference latency.
Create LoRA layer
According to the technical description above, let's create a LoRA layer. In this example, LoRA layers are created and injected for the query and value projection matrices of each transformer layer. In keras.layers.MultiHeadAttention, the query/value projection layers are keras.layers.EinsumDense layers.
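Because the projections are EinsumDense layers, their kernels are not plain n x n matrices. The sketch below shows one plausible layout, an assumption based on how multi-head attention splits the hidden size: the projection kernel is (hidden_dim, num_heads, head_dim), A is an ordinary hidden_dim -> rank kernel, and B maps rank to (num_heads, head_dim) so its output can be added directly to the frozen layer's output. The weight counts still match the 589,824 and 6,144 computed earlier.

```python
from math import prod
from typing import Dict, Tuple


def lora_kernel_shapes(
    hidden_dim: int, num_heads: int, head_dim: int, rank: int
) -> Dict[str, Tuple[int, ...]]:
    """Kernel shapes for a frozen query/value projection and its LoRA pair (assumed layout)."""
    return {
        "original": (hidden_dim, num_heads, head_dim),  # (batch, seq, hidden) -> (batch, seq, heads, head_dim)
        "lora_A": (hidden_dim, rank),  # down-projection, like a Dense layer
        "lora_B": (rank, num_heads, head_dim),  # up-projection into the per-head layout
    }


# GPT-2 base: hidden size 768 split into 12 heads of size 64, LoRA rank 4.
shapes = lora_kernel_shapes(hidden_dim=768, num_heads=12, head_dim=64, rank=4)
for name, shape in shapes.items():
    print(f"{name:>8}: {shape} -> {prod(shape):,} weights")
```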
Inject LoRA layer into the model
We will now hack the original GPT-2 model and inject LoRA layers into it. Let's do a few things before that:
Delete the previous model;
Reset "peak" GPU memory usage using tf.config.experimental.reset_memory_stats;
Load a new GPT-2 model.
We will now override the original query/value projection matrices with our new LoRA layers.
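The injection itself is a loop over the transformer layers that swaps each query and value projection for a wrapper around it. The sketch below shows only that pattern on hypothetical stand-in classes; in KerasHub the attention sublayers are internal attributes whose names may differ between versions, so every class and attribute name here is illustrative.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SelfAttention:
    """Hypothetical stand-in for an attention block with query/value projections."""

    query_dense: object
    value_dense: object


@dataclass
class TransformerLayer:
    """Hypothetical stand-in for one GPT-2 decoder layer."""

    self_attention: SelfAttention


@dataclass
class LoraWrapper:
    """Hypothetical LoRA wrapper that keeps a reference to the frozen layer it adapts."""

    original_layer: object
    rank: int


def inject_lora(layers: List[TransformerLayer], rank: int) -> int:
    """Replace every query/value projection with a LoRA-wrapped version, in place."""
    replaced = 0
    for layer in layers:
        attention = layer.self_attention
        attention.query_dense = LoraWrapper(attention.query_dense, rank)
        attention.value_dense = LoraWrapper(attention.value_dense, rank)
        replaced += 2
    return replaced


# 12 decoder layers, as in GPT-2 base; projections are placeholder strings here.
model_layers = [TransformerLayer(SelfAttention(f"q{i}", f"v{i}")) for i in range(12)]
num_replaced = inject_lora(model_layers, rank=4)
print(num_replaced)  # 24
```

Keeping the original layer inside the wrapper matters: the frozen W0 and b0 are still needed for the forward pass and, later, for merging.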
Let's now do a forward pass to make sure we still have a valid chain of computation.
Freeze the entire LLM; only the LoRA layers should be trainable.
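One simple way to express "freeze everything except LoRA" is a predicate on weight names. The sketch assumes LoRA weights can be recognized by "lora_A" or "lora_B" in their names; that naming scheme, and the single attention block used as data, are hypothetical.

```python
from typing import Dict, Tuple


def freeze_all_but_lora(
    param_sizes: Dict[str, int], lora_markers: Tuple[str, ...] = ("lora_A", "lora_B")
) -> Dict[str, bool]:
    """Return a name -> trainable map in which only LoRA weights stay trainable."""
    return {name: any(marker in name for marker in lora_markers) for name in param_sizes}


def count_params(param_sizes: Dict[str, int], trainable: Dict[str, bool]) -> Tuple[int, int]:
    """Sum parameter counts into (trainable, non_trainable)."""
    num_trainable = sum(size for name, size in param_sizes.items() if trainable[name])
    return num_trainable, sum(param_sizes.values()) - num_trainable


# A hypothetical single attention block with LoRA on query and value (biases omitted).
params = {
    "attention/query/kernel": 768 * 768,
    "attention/query/lora_A/kernel": 768 * 4,
    "attention/query/lora_B/kernel": 4 * 768,
    "attention/value/kernel": 768 * 768,
    "attention/value/lora_A/kernel": 768 * 4,
    "attention/value/lora_B/kernel": 4 * 768,
}
flags = freeze_all_but_lora(params)
print(count_params(params, flags))  # (12288, 1179648)
```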
Print the model's summary and check that the number of non-trainable parameters and the total number of parameters are correct.
In a previous section, we calculated the number of parameters associated with the LoRA layers of one projection to be 6,144. The total number of trainable parameters in the model should therefore be num_layers * 2 (query and value) * 6,144 = 12 * 2 * 6,144 = 147,456. The number of non-trainable parameters should be the same as the total number of parameters in the original GPT-2 model, which is 124,439,808.
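The expected summary numbers can be checked with a few lines of arithmetic; the total is simply the original parameter count plus the LoRA weights. This is a sanity check on the numbers above, not output captured from the model summary.

```python
num_layers = 12  # transformer layers in GPT-2 base
projections_per_layer = 2  # query and value
lora_params_per_projection = 768 * 4 + 4 * 768  # A and B, from the earlier math

trainable = num_layers * projections_per_layer * lora_params_per_projection
non_trainable = 124_439_808  # original GPT-2 parameter count
total = trainable + non_trainable

print(f"trainable:     {trainable:>11,}")
print(f"non-trainable: {non_trainable:>11,}")
print(f"total:         {total:>11,}")
print(f"trainable fraction: {trainable / total:.3%}")
```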
Fine-tune LoRA GPT-2
Now that we have hacked and verified the LoRA GPT-2 model, let's train it!
And we are done fine-tuning the model! Before we generate text, let's compare the training time and memory usage of the two models. The training time of GPT-2 on a 16 GB Tesla T4 (Colab) is 7 minutes, and for LoRA, it is 5 minutes, a decrease of roughly 30%. The memory usage of LoRA GPT-2 is roughly 35% lower than that of GPT-2.
Merge weights and generate text!
One of the biggest advantages of LoRA over other adapter methods is that it does not incur any additional inference latency. Let's understand why.
Recall our LoRA equation: output = W0x + b0 + BAx. We can rewrite this as output = (W0 + BA)x + b0 = Wx + b0, where W = W0 + BA. This means that if we merge the weights of the original model and the adapter, we will essentially be doing the same computation as the original model, with no extra layers to evaluate at inference time!
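The merge can be verified on a tiny example: fold BA into W0 once, then check that the merged layer gives the same output as the frozen layer plus adapter. This is a pure-Python sketch with made-up numbers, using the same column-vector convention as the equation (B is n x rank, A is rank x n). If the adapter output is scaled by a constant (the LoRA paper scales it by alpha / rank), the same constant must multiply BA before it is added to W0.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of an (m x k) and a (k x p) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix by a column vector."""
    return [sum(x * y for x, y in zip(row, v)) for row in m]


def merge_lora(W0: Matrix, A: Matrix, B: Matrix) -> Matrix:
    """Fold the adapter into the frozen weights: W = W0 + BA."""
    BA = matmul(B, A)
    return [[w + d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W0, BA)]


W0 = [[1.0, 2.0], [3.0, 4.0]]
b0 = [0.5, -0.5]
A = [[1.0, -1.0]]  # rank 1: a (rank x n) matrix
B = [[2.0], [1.0]]  # an (n x rank) matrix
x = [3.0, 1.0]

unmerged = [w + b + u for w, b, u in zip(matvec(W0, x), b0, matvec(B, matvec(A, x)))]
W = merge_lora(W0, A, B)
merged = [w + b for w, b in zip(matvec(W, x), b0)]
print(unmerged, merged)  # [9.5, 14.5] [9.5, 14.5]
```

After merging, the A and B layers can be discarded: the model has exactly the original architecture and shapes, which is why inference costs the same as for the original GPT-2.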
We are now all set to generate text with our LoRA model 😃.
And we're all done!