Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/examples/causal_language_modeling_flax.ipynb
Views: 2535
Pre-Training a 🤗 Transformers model on TPU with Flax/JAX
In this notebook, we will see how to pretrain one of the 🤗 Transformers models on TPU using Flax.
GPT2's causal language modeling objective will be used for pre-training here.
As can be seen on this benchmark using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.
Flax is a high-performance neural network library designed for flexibility built on top of JAX (see below). It aims to provide users with full control of their training code and is carefully designed to work well with JAX transformations such as grad
and pmap
(see the Flax philosophy). For an introduction to Flax see the Flax Basic Colab or the list of curated Flax examples.
JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the JAX 101 Tutorial.
You also will need to set up the TPU for JAX in this notebook. This can be done by executing the following lines.
If everything is set up correctly, the following command should return a list of 8 TPU devices.
In this notebook, we will pre-train an autoregressive model on one of the languages of the OSCAR corpus. OSCAR is a huge multilingual corpus obtained by language classification and filtering of the Common Crawl corpus using the goclassy architecture.
Let's first select the language that our model should learn. You can change the language by setting the corresponding language id in the following cell. The language ids can be found under the "File deduplicated" column on the official OSCAR website.
Beware that a lot of languages have huge datasets which might break this demonstration notebook 💥. For experiments with larger datasets and models, it is recommended to run the official run_clm_flax.py
script offline that can be found here.
Here we select is
for Icelandic 🇮🇸.
Next, we select the model architecture to be trained from scratch. Here we choose distilgpt2
, but essentially any auto-regressive model that is available on the 🤗 hub in JAX/Flax can be used.
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
1. Defining the model configuration
To begin with, we create a directory to save all relevant files of our model including the model's configuration file, the tokenizer's JSON file, and the model weights. We call the directory "distilgpt2-base-pretrained-is"
:
and create it:
Next, we'll download the model configuration:
and save it to the directory:
2. Training a tokenizer from scratch
One has to pre-process the raw text data to a format that is understandable by the model. In NLP, the de-facto standard is to use a tokenizer to pre-process data as explained here.
We can leverage the blazing-fast 🤗 Tokenizer library to train a ByteLevelBPETokenizer from scratch.
Let's import the necessary building blocks from tokenizers
and the load_dataset
function.
We will store our tokenizer files and model files in a directory, called model_dir
. We can load our chosen dataset conveniently using the load_dataset
function.
Having imported the ByteLevelBPETokenizer
, we instantiate it,
define a training iterator,
and train the tokenizer by defining vocab_size
according to our model's configuration along with the min_frequency
as well as some special_tokens
:
Finally, we save the trained tokenizer in the model folder.
For more information on training tokenizers, see this document.
3. Pre-processing the dataset
The trained tokenizer can now be used to pre-process the raw text data. GPT2 was trained to generate tokens up to 1024
tokens, see paper here. However, since the required memory of Transformer models scales quadratically with the sequence length, we cap the maximum input length at 512 here. The raw text data is pre-processed accordingly.
To cross-validate the model's performance during pre-training, we hold out 5% of the data as the validation set.
Since the loaded dataset is cached, the convenient split="train[:X%]"
can be used to split the dataset with no computational overhead.
The first 95% percent will be used as the training data:
and the final 5% as the validation data.
For demonstration purposes, we will use only the first 10000 samples of the training data and the first 1000 samples of the validation data to not have to wait too much for each cell to be executed.
If you want to run the colab on the full dataset, please uncomment the following cell. In this case the notebook will run for ca. 7 hours until convergence and give a final loss and perplexity of ca. 3.67 and 39.12 respectively. Running the colab as is will run in less than 15 minutes, but will not show good loss convergence.
Next, we load the previously trained ByteLevelBPETokenizer
tokenizer to pre-process the raw text data:
We can then write the function that will preprocess the raw text data. We just feed the text samples - stored in the "text"
column - to the tokenizer and make sure the mask for special tokens is created:
and apply the tokenization function to every text sample via the convenient map(...)
function of Datasets. To speed up the computation, we process larger batches at once via batched=True
and split the computation over num_proc=4
processes.
Note: Running this command on the whole dataset might take up to 10 minutes ☕.
The model can process the training data most efficiently if all data samples are of the same length. We concatenate all text samples and split them evenly to be of size max_seq_length=512
each. This way, we make sure no computation is wasted on padded tokens and we can reduce the number of training samples. Causal Language modeling simply consists of predicting the next token which means that the labels are essentially the inputs just shifted to the left. Thus, we copy the input_ids
tensor and set it to labels
.
Let's define such a function to group the dataset into equally sized data samples:
We pass group_texts
to the map(...)
function and set batched=True
to make sure that the function is applied to a large batch of data samples.
Note: Running this function on the whole dataset might take up to 50 minutes 🕒.
Awesome, the data is now fully pre-processed and ready to be used for training 😎.
4. Pre-Training the model
Now we will see how the power of Google's tensor processing unit (TPU) can be leveraged with Flax/JAX for the compute-intensive pre-training of language models.
We need to import jax
, flax
, optax
, numpy
to define our training loop. Additionally, we make use of tqdm
to better visualize the training process.
At first, we define all relevant hyper-parameters for pretraining in this notebook:
Each TPU will process a batch size of
16
The model is trained for
10
epochsThe learning rate starts at
3e-4
and is successfully linearly decayed with each training stepTo reproduce the training run, a random seed is set to
0
.
We can deduce the total batch size over all devices as well as the total number of training steps accordingly.
In the official GPT2 paper a batch size of 512 is used.
Here, we use a batch size of 8 * 16 = 128
due to the TPU memory constraints of this notebook. When running this script locally on a TPUv3-8, one can easily use batch sizes of up to 8 * 64 = 512
.
Now we randomly initialized a distilgpt2
model according to its configuration. To save memory and improve speed, we initialize the weights directly in bfloat16
by setting dtype=jnp.dtype("bfloat16")
.
Next, we define the learning rate schedule. A simple and effective learning rate schedule is the linear decay with warmup (click here for more information). For simplicity, we set the number of warmup steps simply to 0 here. The schedule is then fully defined by the number of training steps and the learning rate.
It is recommended to use the optax library for training utilities, e.g. learning rate schedules and optimizers.
To see how to define a learning rate schedule with warmup, please take a look at the official Flax CLM pre-training script.
We will be using the standard Adam optimizer with weight decay, called AdamW (Adam + weight decay).
AdamW can easily be imported from optax and is created from the just defined learning rate schedule as well as a couple of other hyper-parameters (beta1, beta2, epsilon) that are hard-coded in this notebook.
For more information on AdamW (Adam + weight decay), one can take a look at this blog post.
Next, we will create the training state that includes the optimizer, the loss function, and is responsible for updating the model's parameters during training.
Most JAX transformations (notably jax.jit) require functions that are transformed to have no side effects. This is because any such side-effects will only be executed once when the Python version of the function is run during compilation (see Stateful Computations in JAX). As a consequence, Flax models (which can be transformed by JAX transformations) are immutable, and the state of the model (i.e., its weight parameters) is stored outside of the model instance.
Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.
Flax provides a convenience class flax.training.train_state.TrainState
, which stores things such as the model parameters, the loss function, the optimizer, and exposes an apply_gradients
function to update the model's weight parameters.
Alright, let's begin by defining our training state class. We create a TrainState
class that stores the model's forward pass as the apply_fn
, the params
, and the AdamW optimizer.
Next, let's implement a data loader for both training and evaluation. The data loader can be defined as a Python generator that returns a batch model input every time it is called.
First, a random permutation of the whole dataset is defined. Then, every time the training data collator is called the next batch of the randomized dataset is extracted, converted to a JAX array and sharded over all local TPU devices.
At each training epoch, the dataset should be shuffled and superfluous samples that make the dataset not evenly divisible by the batch size are thrown away. Instead of passing the dataset, we prepare the indices of data samples to be used for both each training epoch. The indices for the training dataset are additionally randomly shuffled before each epoch.
During fine-tuning, we want to update the model parameters and evaluate the performance after each epoch.
Let's write the functions train_step
and eval_step
accordingly. During training the weight parameters should be updated as follows:
Define a loss function
loss_function
that first runs a forward pass of the model given data input. Remember that Flax models are immutable, and we explicitly pass it the state (in this case the model parameters and the RNG).loss_function
returns a scalar loss (using the previously definedstate.loss_function
) between the model output and input targets.Differentiate this loss function using
jax.value_and_grad
. This is a JAX transformation called automatic differentiation, which computes the gradient ofloss_function
given the input to the function (i.e., the parameters of the model), and returns the value and the gradient in a pair(loss, gradients)
.Compute the mean gradient over all devices using the collective operation lax.pmean. As we will see below, each device runs
train_step
on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.Use
state.apply_gradients
, which applies the gradients to the weights.
Below, you can see how each of the described steps above is put into practice.
Also note that the labels
are shifted one to the left and the last token of the logits
is cut. This way, the model learns to predict the next token as defined in causal language modeling.
Now, we want to do parallelized training over all TPU devices. To do so, we use jax.pmap
. This will compile the function once and run the same program on each device (it is an SPMD program). When calling this pmapped function, all inputs ("state"
, "batch"
, "dropout_rng"
) should be replicated for all devices, which means that the first axis of each argument is used to map over all TPU devices.
Similarly, we can now define the evaluation step. Here, the function is much easier as we don't need to compute any gradients. To better monitor the performance improvement during training, the next token loss is computed and stored in a metric
dictionary during evaluation.
Similarly, we also apply jax.pmap
to the evaluation step.
Next, we replicate/copy the weight parameters on each device, so that we can pass them to our parallelized mapped functions.
We can almost start training! In a final preparation step, we generate a seeded PRNGKey used as the random seed for dropout layers and dataset shuffling.
Similar to how we had to copy/replicate the state on all 8 TPU devices, we also need to generate one PRNGKey
per device, which is why we split the initial rng
key into 8 random seeds.
Now, we are all set to finally start training! Let's put all the pieces together and write the training loop.
We start each epoch by generating a new random seed that will be used for dataset shuffling, the dropout layers and the input token masking.
Next, we generate the training dataset indices. In the first nested loop - the training loop - we shard the input batch on all 8 TPU devices, and run the training step.
Analogs, in the second nested loop - the evaluation loop - the evaluation batches are sharded and the evaluation step is run.
Note: It might seem that the following cell "hangs" when executed for the first time. This is because JAX first traces & compiles the code, the very first time it is run. After the first training step, you should notice that execution is much faster.
Train... (1/10 | Loss: 6.935000419616699, Learning Rate: 0.0002699999895412475)
Eval... (1/10 | Loss: 7.108445644378662 | Perplexity: 1246.529052734375)
Train... (2/10 | Loss: 6.334000110626221, Learning Rate: 0.00023999999393709004)
Eval... (2/10 | Loss: 6.567610740661621 | Perplexity: 738.8753662109375)
Train... (3/10 | Loss: 5.798000335693359, Learning Rate: 0.0002099999983329326)
Eval... (3/10 | Loss: 6.278167247772217 | Perplexity: 557.9488525390625)
Train... (4/10 | Loss: 5.557000160217285, Learning Rate: 0.00018000000272877514)
Eval... (4/10 | Loss: 6.062875270843506 | Perplexity: 451.3289794921875)
Train... (5/10 | Loss: 5.543000221252441, Learning Rate: 0.00014999999257270247)
Eval... (5/10 | Loss: 5.920379161834717 | Perplexity: 392.97332763671875)
Train... (6/10 | Loss: 5.361000061035156, Learning Rate: 0.00011999999696854502)
Eval... (6/10 | Loss: 5.821027755737305 | Perplexity: 356.4353942871094)
Train... (7/10 | Loss: 5.207000255584717, Learning Rate: 9.000000136438757e-05)
Eval... (7/10 | Loss: 5.748736381530762 | Perplexity: 332.1453857421875)
Train... (8/10 | Loss: 5.124000072479248, Learning Rate: 5.999999848427251e-05)
Eval... (8/10 | Loss: 5.703180313110352 | Perplexity: 317.5106201171875)
Train... (9/10 | Loss: 5.220000267028809, Learning Rate: 2.9999999242136255e-05)
Eval... (9/10 | Loss: 5.674434185028076 | Perplexity: 308.7478942871094)
Train... (10/10 | Loss: 4.992000102996826, Learning Rate: 0.0)
Eval... (10/10 | Loss: 5.66389274597168 | Perplexity: 305.58953857421875)
It can be seen that in this colab training already reaches a speed of 2.42 training steps per second. Executing run_clm_flax.py
on a TPUv3-8 VM should be as fast as 7 training steps per second.
For a more in-detail comparison of runtimes please refer to this table.