
Fine-tuning a 🤗 Transformers model on TPU with Flax/JAX

In this notebook, we will see how to fine-tune one of the 🤗 Transformers models on TPU using Flax. As shown in this benchmark, using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.

Flax is a high-performance neural network library designed for flexibility, built on top of JAX (see below). It aims to give users full control of their training code and is carefully designed to work well with JAX transformations such as grad and pmap (see the Flax philosophy). For an introduction to Flax, see the Flax Basic Colab or the list of curated Flax examples.

JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the JAX 101 Tutorial.
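As a minimal illustration of these composable transformations (purely for intuition; it is not needed for the rest of this notebook), here is a tiny function being differentiated and JIT-compiled:

# Illustration only: differentiate and JIT-compile a tiny function with JAX.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

grad_f = jax.jit(jax.grad(f))  # d/dx sum(x^2) = 2x, compiled with XLA
print(grad_f(jnp.arange(3.0)))  # [0. 2. 4.]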

If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets as well as Flax and Optax. Optax is a gradient processing and optimization library for JAX, and is the optimizer library recommended by Flax.

%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

You will also need to set up the TPU for JAX in this notebook. This can be done by executing the following lines.

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

If everything is set up correctly, the following command should return a list of 8 TPU devices.

jax.local_devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.

You can find a script version of this notebook to fine-tune your model here.

As an example, we will fine-tune a pretrained auto-encoding model on a text classification task of the GLUE Benchmark. Note that this notebook does not focus so much on data preprocessing, but rather on how to write a training and evaluation loop in JAX/Flax. If you want more detailed explanations regarding the data preprocessing, please check out this notebook.

The GLUE Benchmark is a group of nine tasks on sentences or pairs of sentences, all classification except STS-B, which is a regression task: CoLA, MNLI, MRPC, QNLI, QQP, RTE, SST-2, STS-B, WNLI.

We will see how to easily load the dataset for each one of those tasks and how to write a training loop in Flax. Each task is named by its acronym, with mnli-mm standing for the mismatched version of MNLI (so same training set as mnli but different validation and test sets):

GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

This notebook is built to run on any of the tasks in the list above, with any Flax/JAX model checkpoint from the Model Hub as long as that model has a version with a classification head. Depending on the model you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those three parameters, then the rest of the notebook should run smoothly:

task = "cola" model_checkpoint = "bert-base-cased" per_device_batch_size = 4

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("text_classification_notebook", framework="flax")

Loading the dataset

We will use the 🤗 Datasets library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions load_dataset and load_metric.

from datasets import load_dataset, load_metric

Apart from mnli-mm being a special code, we can directly pass our task name to those functions. load_dataset will cache the dataset to avoid downloading it again the next time you run this cell.

actual_task = "mnli" if task == "mnli-mm" else task
is_regression = task == "stsb"

raw_dataset = load_dataset("glue", actual_task)
metric = load_metric('glue', actual_task)
/tmp/ipykernel_253934/2906828129.py:5: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate metric = load_metric('glue', actual_task)
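As the warning above notes, load_metric is deprecated in recent versions of 🤗 Datasets. If you prefer, an equivalent using the newer 🤗 Evaluate library (an optional alternative that assumes the evaluate package is installed, which is not done above) would be:

# Optional alternative to load_metric (requires `pip install evaluate`):
# import evaluate
# metric = evaluate.load("glue", actual_task)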

Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers Tokenizer, which will (as the name indicates) tokenize the inputs. This includes converting the tokens to their corresponding IDs in the pretrained vocabulary and putting them in a format the model expects, as well as generating the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method, which will ensure:

  • we get a tokenizer that corresponds to the model architecture we want to use,

  • we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

To preprocess our dataset, we will thus need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence task to column names:

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True. This will ensure that an input longer than what the model can handle will be truncated to the maximum length accepted by the model.

sentence1_key, sentence2_key = task_to_keys[task]

def preprocess_function(examples):
    texts = (
        (examples[sentence1_key],)
        if sentence2_key is None
        else (examples[sentence1_key], examples[sentence2_key])
    )
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = examples["label"]
    return processed

To apply this function to all the sentences (or pairs of sentences) in our dataset, we just use the map method of our dataset object we created earlier. This will apply the function on all the elements of all the splits in the dataset, so our training, validation, and testing data will be preprocessed in one single command.

tokenized_dataset = raw_dataset.map(preprocess_function, batched=True, remove_columns=raw_dataset["train"].column_names)
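As a quick sanity check (an optional illustration, not part of the original flow), you can peek at the first processed example to confirm the expected columns are there:

# The exact set of keys depends on the tokenizer; for bert-base-cased it
# should contain input_ids, token_type_ids, attention_mask and labels.
print(tokenized_dataset["train"][0].keys())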

As a final step, we split the dataset into train and validation sets and give each one a more explicit name.

train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]

Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since all our tasks are about sentence classification, we use the FlaxAutoModelForSequenceClassification class. Like with the tokenizer, the from_pretrained method will download and cache the model for us.

All weight parameters that are not found in the pretrained model weights will be randomly initialized upon instantiating the model class. Because the GLUE tasks have relatively small training datasets, a different seed for weight initialization might very well lead to significantly different results. For reproducibility, we set the random seed to 0 in this notebook.

The only thing we have to specify in the config is the number of labels for our problem (which is always 2, except for STS-B which is a regression problem, and MNLI where we have 3 labels):

from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2
seed = 0

config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed)
Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('bert', 'pooler', 'dense', 'kernel'), ('classifier', 'kernel'), ('bert', 'pooler', 'dense', 'bias'), ('classifier', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The warning is telling us we are throwing away some weights (the cls.predictions layers of the masked language modeling head) and randomly initializing some others (the pooler and classifier layers). This is normal in this case, because we are removing the head used to pretrain the model on a masked language modeling objective and replacing it with a new head for which we don't have pretrained weights. The library warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.

To write a full training and evaluation loop in Flax, we will need to import a couple of packages.

import flax
import jax
import optax
from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable

import jax.numpy as jnp
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state

For all GLUE tasks except "wnli" and "mrpc", training for just 3 epochs is usually sufficient. "wnli" and "mrpc" are so small that we recommend training for 5 epochs. We use a learning rate of 2e-5.

num_train_epochs = 3 if task not in ["mrpc", "wnli"] else 5
learning_rate = 2e-5

We've already set the batch size per device, but are now interested in the effective total batch_size:

total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)
The overall batch size (both for training and eval) is 32

Next, we define the learning rate schedule. A simple and effective learning rate schedule is the linear decay with warmup (click here for more information). Since GLUE datasets are rather small and therefore have few training steps, we set the number of warmup steps simply to 0. The schedule is then fully defined by the number of training steps and the learning rate.

It is recommended to use the optax library for training utilities, e.g. learning rate schedules and optimizers.

num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs

learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
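If you do want a warmup phase, optax.join_schedules can combine a linear warmup with the linear decay used above. This is a hedged sketch rather than part of this notebook's setup; num_warmup_steps and the variable names below are illustrative choices only.

# Optional sketch: linear warmup followed by linear decay.
num_warmup_steps = 100  # illustrative value, not tuned
warmup_function = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_function = optax.linear_schedule(init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps)
schedule_with_warmup = optax.join_schedules(schedules=[warmup_function, decay_function], boundaries=[num_warmup_steps])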

Defining the training state

Next, we will create the training state, which includes the optimizer and the loss function and is responsible for updating the model's parameters during training.

Most JAX transformations (notably jax.jit) require functions that are transformed to have no side-effects. This is because any such side-effects will only be executed once, when the Python version of the function is run during compilation (see Stateful Computations in JAX). As a consequence, Flax models (which can be transformed by JAX transformations) are immutable, and the state of the model (i.e., its weight parameters) is stored outside of the model instance.

Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.
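As a small illustration of this (not needed for the fine-tuning itself), calling the Flax model is a pure function of the inputs and the explicitly passed parameters. The snippet below assumes the model loaded above and uses dummy inputs:

# Illustration only: the parameters live outside the model instance and are
# passed explicitly on every call.
dummy_input_ids = jnp.ones((1, 128), dtype="i4")
dummy_logits = model(input_ids=dummy_input_ids, params=model.params, train=False)[0]
print(dummy_logits.shape)  # (1, num_labels)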

Flax provides a convenience class flax.training.train_state.TrainState, which stores things such as the model parameters, the loss function, the optimizer, and exposes an apply_gradients function to update the model's weight parameters.

Alright, let's begin by defining our training state class. We create a derived TrainState class that additionally stores the model's evaluation function as logits_function as well as a loss_function.

class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)

We will be using the standard Adam optimizer with weight decay. For more information on AdamW (Adam + weight decay), one can take a look at this blog post.

AdamW can easily be imported from optax:

Applying weight decay to the bias and/or LayerNorm parameters has not been shown to improve performance and can even be detrimental, which is why we disable it here. For more information on this, please check out the following blog post or paper.

Hence we create a decay_mask_fn which makes sure that weight decay is not applied to any bias or LayerNorm weights. This can easily be done by passing our decay_mask_fn as the mask argument of optax.adamw.

from flax import traverse_util

def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
import optax

def adamw(weight_decay):
    return optax.adamw(learning_rate=learning_rate_function, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn)

Now we also need to define the evaluation and loss function given the model's output logits. For regression tasks, the evaluation function simply takes the first logit element and the mean-squared-error (MSE) loss is used. For classification tasks, the evaluation function uses the argmax of the logits, and the cross-entropy loss over num_labels classes is used.

def loss_function(logits, labels):
    if is_regression:
        return jnp.mean((logits[..., 0] - labels) ** 2)
    xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
    return jnp.mean(xentropy)

def eval_function(logits):
    return logits[..., 0] if is_regression else logits.argmax(-1)

Finally, we put the pieces together to instantiate a TrainState.

state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)

Defining the training and evaluation step

During fine-tuning, we want to update the model parameters and evaluate the performance after each epoch.

Let's write the functions train_step and eval_step accordingly. During training the weight parameters should be updated as follows:

  1. Define a loss function loss_function that first runs a forward pass of the model given the data input. Remember that Flax models are immutable, so we explicitly pass in the state (in this case the model parameters and the dropout RNG). loss_function returns a scalar loss (using the previously defined state.loss_function) between the model output and the input targets.

  2. Differentiate this loss function using jax.value_and_grad. This JAX transformation performs automatic differentiation: it computes the gradient of loss_function with respect to its input (i.e., the parameters of the model) and returns the value and the gradient as a pair (loss, gradients). A standalone toy example of this transformation follows the list.

  3. Compute the mean gradient over all devices using the collective operation lax.pmean. As we will see below, each device runs train_step on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.

  4. Use state.apply_gradients, which applies the gradients to the weights.
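As promised above, here is a standalone toy example (illustration only) of jax.value_and_grad from step 2:

# Toy example: value_and_grad returns both the loss value and its gradient.
def toy_loss(w):
    return jnp.mean((w - 1.0) ** 2)

toy_value, toy_grad = jax.value_and_grad(toy_loss)(jnp.zeros(3))
print(toy_value, toy_grad)  # 1.0 and [-0.667 -0.667 -0.667], i.e. 2 * (w - 1) / 3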

Below, you can see how each of the steps described above is put into practice.

def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_function(state.step)}, axis_name="batch")
    return new_state, metrics, new_dropout_rng

Now, we want to do parallelized training over all TPU devices. To do so, we use jax.pmap. This will compile the function once and run the same program on each device (it is an SPMD program). When calling this pmapped function, all inputs ("state", "batch", "dropout_rng") should be replicated for all devices, which means that the first axis of each argument is used to map over all TPU devices.

The argument donate_argnums is used to tell JAX that the first argument "state" is "donated" to the computation, because it is not needed anymore afterwards. XLA can make use of donated buffers to reduce the memory needed.

parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
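To build intuition for this SPMD execution model (a toy illustration assuming the 8 local TPU devices listed earlier; the variable names are illustrative), the following snippet maps a function over the leading axis of its input and averages the per-device values with lax.pmean:

# Toy illustration of pmap + pmean: each device receives one element of the
# leading axis, and pmean averages across the named axis on every device.
n_devices = jax.local_device_count()
per_device_values = jnp.arange(float(n_devices))
print(jax.pmap(lambda x: jax.lax.pmean(x, "batch"), axis_name="batch")(per_device_values))
# -> the same mean, replicated once per device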

Similarly, we can now define the evaluation step. Here, the function is much simpler, as it only needs to combine the model's forward pass with the previously defined eval_function (or logits_function).

def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

We then also apply jax.pmap to the evaluation step.

parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

Defining the data collators

In a final step before we can start training, we need to define the data collators. The data collators are responsible for shuffling the training data before each epoch and for preparing the batches for each training and evaluation step.

Let's start with the training collator.

The training collator can be defined as a Python generator that returns a batch of model inputs every time it is called.

First, a random permutation of the whole dataset is computed. Then, every time the training data collator is called, the next batch of the permuted dataset is extracted, converted to JAX arrays and sharded over all local TPU devices.

def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch
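To see what the collator actually produces (an optional check, assuming the 8 local devices and the total batch size of 32 from above), you can pull one batch and inspect the shapes after shard:

# Each array is reshaped from (total_batch_size, ...) to
# (num_devices, per_device_batch_size, ...), e.g. (32, 128) -> (8, 4, 128).
example_batch = next(glue_train_data_loader(jax.random.PRNGKey(0), train_dataset, total_batch_size))
print({k: v.shape for k, v in example_batch.items()})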

We define the eval data collator in a similar fashion.

Note: For simplicity, we throw away the last incomplete batch, since it can't be easily sharded over all devices. This means that the evaluation results might be slightly off. This could be fixed by also evaluating the leftover examples, e.g. on a single device after the sharded evaluation loop (see the sketch after the next cell).

def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch
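If you want the evaluation metric to be exact, one possible fix is to evaluate the leftover examples on a single device, right after the sharded evaluation loop in the training section below. This is a hedged sketch, not part of the original notebook flow, and evaluate_leftover_examples is a hypothetical helper name:

# Sketch: handle the examples that don't fill a complete batch. This assumes
# it is called after the sharded evaluation loop, where `state` has been
# replicated across devices, so we first pull back a single copy.
def evaluate_leftover_examples(state, dataset, batch_size):
    num_leftover = len(dataset) % batch_size
    if num_leftover == 0:
        return
    leftover = dataset[len(dataset) - num_leftover :]
    labels = leftover.pop("labels")
    batch = {k: jnp.array(v) for k, v in leftover.items()}
    single_device_state = flax.jax_utils.unreplicate(state)
    predictions = eval_step(single_device_state, batch)
    metric.add_batch(predictions=predictions, references=labels)

# e.g. call evaluate_leftover_examples(state, eval_dataset, total_batch_size)
# right before metric.compute() in the training loop below.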

Next, we replicate/copy the weight parameters on each device, so that we can pass them to our pmapped functions.

state = flax.jax_utils.replicate(state)

Training

Finally, we can write down the full training loop.

Let's start by generating a seeded PRNGKey for the dropout layers and dataset shuffling.

rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

Now we define the full training loop. For each batch in each epoch, we run a training step. Here, we also need to make sure that the PRNGKey is sharded/split over each device. Having completed an epoch, we report the training metrics and can run the evaluation.

for i, epoch in enumerate(tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
            progress_bar_train.update(1)

    # evaluate
    with tqdm(total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3)
    eval_score = round(list(eval_metric.values())[0], 3)
    metric_name = list(eval_metric.keys())[0]
    print(f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name}: {eval_score}")
2023-07-30 01:26:23.840901: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
1/3 | Train loss: 0.5 | Eval matthews_correlation: 0.528
2/3 | Train loss: 0.298 | Eval matthews_correlation: 0.57
3/3 | Train loss: 0.048 | Eval matthews_correlation: 0.595

To see how your model fared you can compare it to the GLUE Benchmark leaderboard.

Sharing fine-tuned model

Now that you've successfully trained a model, you can share it with the community by uploading the fine-tuned model checkpoint and tokenizer to your account on the hub.

If you don't have an account yet, you can click here to join the community 🤗.

As a first step, we install git-lfs so that we can easily upload the model weights.

%%capture
!sudo apt-get install software-properties-common
!sudo curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
!sudo apt-get install git-lfs
!git lfs install

Next, you will need to store your git credentials so that git knows who is uploading the files. Replace the fields <your-email-address> and <your-name> with your credentials accordingly.

!git config --global user.email "<your-email-address>"  # e.g. "[email protected]"
!git config --global user.name "<your-name>"  # e.g. "Patrick von Platen"

You will need to pass your user authentication token hf_auth_token to allow the 🤗 hub to upload model weights under your username. To find your authentication token, log in here, click on your icon (top right), go to Settings, then API Tokens. Copy the User API token and paste it into the field <your-auth-token> below.

hf_auth_token = "<your-auth-token>" # e.g. api_DaYgaaVnGdRtznIgiNfotCHFUqmOdARmPx
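Alternatively (an optional route that assumes the huggingface_hub package, which is installed alongside 🤗 Transformers), you can log in interactively instead of pasting the token into a variable:

# Optional: interactive login stores the token for you, so you don't have to
# keep it in the notebook.
from huggingface_hub import notebook_login
notebook_login()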

Finally, you can give your fine-tuned model a nice model_id, or leave the default one as defined below. The model will be uploaded under https://huggingface.co/<your-username>/<your-model-id>, e.g.

model_id = f"{model_checkpoint}_fine_tuned_glue_{task}"

Great! Now all that is left to do is to upload your model:

model.push_to_hub(model_id, use_auth_token=hf_auth_token)
tokenizer.push_to_hub(model_id, use_auth_token=hf_auth_token)

You can now go to your model page to check it out 😊.

We strongly recommend adding a model card so that the community can actually make use of your fine-tuned model. You can do so by clicking on Create Model Card and adding a descriptive text in markdown format.

A simple description for your fine-tuned model can be generated by running the following cell. You can simply copy-paste the output to use as your model card README.md.

print(f"""--- language: en license: apache-2.0 datasets: - glue --- # {" ".join([x.capitalize() for x in model_id.split("_")])} This checkpoint was initialized from the pre-trained checkpoint {model_checkpoint} and subsequently fine-tuned on GLUE task: {task} using [this](https://colab.research.google.com/drive/162pW3wonGcMMrGxmA-jdxwy1rhqXd90x?usp=sharing) notebook. Training was conducted for {num_train_epochs} epochs, using a linear decaying learning rate of {learning_rate}, and a total batch size of {total_batch_size}. The model has a final training loss of {loss} and a {metric_name} of {eval_score}. """)

An uploaded model would, e.g., look as follows: patrickvonplaten/bert-base-cased_fine_tuned_glue_mrpc_demo.