
GitHub Repository: huggingface/notebooks
Path: blob/main/examples/causal_language_modeling_flax.ipynb
Kernel: Python 3 (ipykernel)


Pre-Training a 🤗 Transformers model on TPU with Flax/JAX

In this notebook, we will see how to pretrain one of the 🤗 Transformers models on TPU using Flax.

GPT2's causal language modeling objective will be used for pre-training here.

As can be seen on this benchmark, using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.

Flax is a high-performance neural network library built on top of JAX (see below) and designed for flexibility. It aims to provide users with full control of their training code and is carefully designed to work well with JAX transformations such as grad and pmap (see the Flax philosophy). For an introduction to Flax, see the Flax Basic Colab or the list of curated Flax examples.

JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the JAX 101 Tutorial.

If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Tokenizers as well as Flax and Optax. Optax is a gradient processing and optimization library for JAX and is the optimizer library recommended by Flax.

%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install tokenizers
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

You also will need to set up the TPU for JAX in this notebook. This can be done by executing the following lines.

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

If everything is set up correctly, the following command should return a list of 8 TPU devices.

jax.local_devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In this notebook, we will pre-train an autoregressive model on one of the languages of the OSCAR corpus. OSCAR is a huge multilingual corpus obtained by language classification and filtering of the Common Crawl corpus using the goclassy architecture.

Let's first select the language that our model should learn. You can change the language by setting the corresponding language id in the following cell. The language ids can be found under the "File deduplicated" column on the official OSCAR website.

Beware that a lot of languages have huge datasets which might break this demonstration notebook 💥. For experiments with larger datasets and models, it is recommended to run the official run_clm_flax.py script, which can be found here, outside of this notebook.

Here we select "is" for Icelandic 🇮🇸.

language = "is"

Next, we select the model architecture to be trained from scratch. Here we choose distilgpt2, but essentially any auto-regressive model that is available on the 🤗 hub in JAX/Flax can be used.

model_config = "distilgpt2"

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("causal_language_modeling_notebook", framework="flax")

1. Defining the model configuration

To begin with, we create a directory to save all relevant files of our model, including the model's configuration file, the tokenizer's JSON file, and the model weights. We call the directory "distilgpt2-pretrained-is":

model_dir = model_config + f"-pretrained-{language}"

and create it:

from pathlib import Path

Path(model_dir).mkdir(parents=True, exist_ok=True)

Next, we'll download the model configuration:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_config)

and save it to the directory:

config.save_pretrained(f"{model_dir}")

2. Training a tokenizer from scratch

One has to pre-process the raw text data into a format that the model can understand. In NLP, the de facto standard is to use a tokenizer to pre-process the data, as explained here.

We can leverage the blazing-fast 🤗 Tokenizers library to train a ByteLevelBPETokenizer from scratch.

Let's import the necessary building blocks from tokenizers and the load_dataset function.

from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
from pathlib import Path

We will store our tokenizer files and model files in a directory, called model_dir. We can load our chosen dataset conveniently using the load_dataset function.

raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_{language}")
WARNING:datasets.builder:Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)

Having imported the ByteLevelBPETokenizer, we instantiate it,

tokenizer = ByteLevelBPETokenizer()

define a training iterator,

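def batch_iterator(batch_size=1000):
    # Iterate over the samples of the "train" split.
    # len(raw_dataset) would count the splits of the DatasetDict (here: 1), not the samples.
    for i in range(0, len(raw_dataset["train"]), batch_size):
        yield raw_dataset["train"][i : i + batch_size]["text"]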

and train the tokenizer by defining vocab_size according to our model's configuration along with the min_frequency as well as some special_tokens:

tokenizer.train_from_iterator(
    batch_iterator(),
    vocab_size=config.vocab_size,
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)
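Conceptually, BPE training builds its vocabulary by repeatedly merging the most frequent pair of adjacent symbols, and min_frequency stops pairs that occur too rarely from being merged. The sketch below shows that merge loop in plain Python on a hypothetical toy corpus. It is a simplified, character-level illustration (no byte-level pre-tokenization, no special tokens, no vocab_size cap) and not the 🤗 Tokenizers implementation; the helper names are made up for this example.

```python
from collections import Counter


def most_frequent_pair(words, min_frequency=2):
    """Return the most frequent adjacent symbol pair, or None if it is rarer than min_frequency.

    `words` maps a word (a tuple of symbols) to how often it occurs in the corpus.
    """
    pair_counts = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            pair_counts[pair] += freq
    if not pair_counts:
        return None
    # max() returns the first maximal item, so ties go to the pair seen first.
    pair, count = max(pair_counts.items(), key=lambda kv: kv[1])
    return pair if count >= min_frequency else None


def merge_pair(words, pair):
    """Replace every occurrence of `pair` inside each word with one merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


# Hypothetical toy corpus: each word starts out split into single characters.
corpus = Counter("low low low lower lowest newer newer".split())
words = {tuple(word): freq for word, freq in corpus.items()}

merges = []
while True:
    pair = most_frequent_pair(words, min_frequency=2)
    if pair is None:
        break
    merges.append(pair)
    words = merge_pair(words, pair)

print(merges)
```

Each recorded merge corresponds to one new vocabulary entry; the real trainer keeps merging until vocab_size is reached or no pair meets min_frequency.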

Finally, we save the trained tokenizer in the model folder.

tokenizer.save(f"{model_dir}/tokenizer.json")

For more information on training tokenizers, see this document.

3. Pre-processing the dataset

The trained tokenizer can now be used to pre-process the raw text data. GPT2 was trained on sequences of up to 1024 tokens, see the paper here. However, since the memory required by the self-attention layers of Transformer models scales quadratically with the sequence length, we cap the maximum input length at 512 here and pre-process the raw text data accordingly.

max_seq_length = 512
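To make the quadratic scaling concrete, the following back-of-the-envelope sketch counts the attention-score entries (one seq_len × seq_len matrix per head and layer) for a single sequence. The defaults n_layer=6 and n_head=12 are assumed to match distilgpt2's configuration; check config.n_layer and config.n_head for your model. Only these score matrices grow quadratically; most other activations grow linearly with the sequence length.

```python
def attention_score_entries(seq_len, n_layer=6, n_head=12):
    """Entries in all attention-score matrices for one sequence (assumed distilgpt2 sizes)."""
    return n_layer * n_head * seq_len * seq_len


full = attention_score_entries(1024)
capped = attention_score_entries(512)
ratio = full // capped  # halving the sequence length quarters the score matrices

print(f"{full:,} vs {capped:,} entries per sequence (ratio {ratio})")
# prints "75,497,472 vs 18,874,368 entries per sequence (ratio 4)"
```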

To evaluate the model's performance during pre-training, we hold out 5% of the data as a validation set.

Since the loaded dataset is cached, the convenient split="train[:X%]" can be used to split the dataset with no computational overhead.

The last 95% will be used as the training data:

raw_dataset["train"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[5%:]")
WARNING:datasets.builder:Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)

and the first 5% as the validation data.

raw_dataset["validation"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[:5%]")
WARNING:datasets.builder:Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)
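The two percentage slices are complementary: "train[:5%]" covers the first 5% of the rows and "train[5%:]" the remaining 95%, so no sample ends up in both splits. The sketch below maps such slices to index ranges for a hypothetical dataset of 1,000 rows. The helper percent_slice is made up for illustration and uses Python's round(); the exact rounding rule used by 🤗 Datasets may differ at the boundaries.

```python
def percent_slice(n_examples, start_pct=None, end_pct=None):
    """Index range covered by a 'train[start%:end%]'-style slice (sketch with simple rounding)."""
    start = 0 if start_pct is None else round(n_examples * start_pct / 100)
    end = n_examples if end_pct is None else round(n_examples * end_pct / 100)
    return range(start, end)


n_rows = 1_000                                   # hypothetical dataset size
val_idx = percent_slice(n_rows, end_pct=5)       # like split="train[:5%]"
train_idx = percent_slice(n_rows, start_pct=5)   # like split="train[5%:]"

print(len(val_idx), len(train_idx))  # prints "50 950"
```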

For demonstration purposes, we will use only the first 20000 samples of the training data and the first 2000 samples of the validation data, so that we don't have to wait too long for each cell to execute.

If you want to run the notebook on the full dataset, please comment out (or skip) the following cell. In that case, the notebook will run for about 7 hours until convergence and reach a final loss and perplexity of about 3.67 and 39.12, respectively. Running the notebook as is takes less than 15 minutes, but will not show good loss convergence.

# these cells should be commented out to run on full dataset
raw_dataset["train"] = raw_dataset["train"].select(range(20000))
raw_dataset["validation"] = raw_dataset["validation"].select(range(2000))

Next, we load the previously trained ByteLevelBPETokenizer tokenizer to pre-process the raw text data:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

We can then write the function that will pre-process the raw text data. We simply feed the text samples, stored in the "text" column, to the tokenizer, which returns the input_ids and an attention_mask for each sample:

def tokenize_function(examples):
    return tokenizer(examples["text"])

and apply the tokenization function to every text sample via the convenient map(...) function of Datasets. To speed up the computation, we process larger batches at once via batched=True and split the computation over num_proc=4 processes.

Note: Running this command on the whole dataset might take up to 10 minutes ☕.

tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-668fe01fa18ae746.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-8c3e31332860f1ac.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-0214751322118ef0.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-0e993781985ea725.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-c1d87c939cb205b9.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-13b87d9a50234587.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-d4365f699bbc79c3.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-e760050a45eb004a.arrow

The model can process the training data most efficiently if all data samples have the same length. We therefore concatenate all text samples and split the result into chunks of max_seq_length=512 tokens each; the leftover tokens at the end that do not fill a whole chunk are dropped. This way, no computation is wasted on padding tokens. Causal language modeling simply consists of predicting the next token, which means that the labels are essentially the inputs shifted one position to the left. We therefore copy the input_ids to labels here; the actual shift is applied later in the loss function.

Let's define such a function to group the dataset into equally sized data samples:

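def group_texts(examples):
    # Concatenate all texts of each column (input_ids, attention_mask).
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder so that every chunk has exactly max_seq_length tokens.
    total_length = (total_length // max_seq_length) * max_seq_length
    # Split into chunks of max_seq_length.
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    # The labels are a copy of the inputs; the one-token shift happens in the loss.
    result["labels"] = result["input_ids"].copy()
    return result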
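To see what the grouping does on a small scale, here is a self-contained copy of the same logic with the chunk size passed as a parameter. The name group_texts_toy and its block_size argument are hypothetical stand-ins for group_texts and the global max_seq_length, so running this cell does not overwrite the notebook's function. With a toy batch of 3 samples (10 tokens in total) and block_size=4, two full chunks are produced and the last 2 tokens are dropped.

```python
def group_texts_toy(examples, block_size):
    # Concatenate each column (e.g. input_ids, attention_mask) across all samples.
    concatenated = {k: sum(examples[k], []) for k in examples}
    total_length = len(next(iter(concatenated.values())))
    # Drop the remainder so that every chunk has exactly block_size tokens.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # Labels are a copy of input_ids; the one-token shift happens later in the loss.
    result["labels"] = result["input_ids"].copy()
    return result


toy_batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1], [1, 1]],
}
grouped = group_texts_toy(toy_batch, block_size=4)
print(grouped["input_ids"])  # prints "[[1, 2, 3, 4], [5, 6, 7, 8]]"
```

Note that chunks may span document boundaries (here, tokens 3 and 4 come from different samples); this is the usual trade-off for avoiding padding.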

We pass group_texts to the map(...) function and set batched=True to make sure that the function is applied to a large batch of data samples.

Note: Running this function on the whole dataset might take up to 50 minutes 🕒.

tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-97c2be27a259abfd.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-f490d080d7dedf65.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-c717c20d8a29b0c7.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-3b486fcdfc86c6d4.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-af63c7c8c3b5ad0a.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-3d949fd35aa4fd76.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-b61580379e98f5c6.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d/cache-c9735f22fa10eb4b.arrow

Awesome, the data is now fully pre-processed and ready to be used for training 😎.

4. Pre-Training the model

Now we will see how the power of Google's tensor processing unit (TPU) can be leveraged with Flax/JAX for the compute-intensive pre-training of language models.

We need to import jax, flax, optax, numpy to define our training loop. Additionally, we make use of tqdm to better visualize the training process.

import jax
import optax
import flax
import jax.numpy as jnp
import math

from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import numpy as np

from tqdm.notebook import tqdm

At first, we define all relevant hyper-parameters for pretraining in this notebook:

  • Each TPU core processes a batch of 16 samples

  • The model is trained for 10 epochs

  • The learning rate starts at 3e-4 and is linearly decayed with each training step, reaching 0 at the end of training

  • To reproduce the training run, a random seed is set to 0.

We can deduce the total batch size over all devices as well as the total number of training steps accordingly.

per_device_batch_size = 16
num_epochs = 10
training_seed = 0
learning_rate = 3e-4

total_batch_size = per_device_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs
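The step count follows from simple integer arithmetic: the global batch is the per-device batch times the number of devices, and the incomplete last batch of each epoch is dropped. The sketch below reproduces that arithmetic with a hypothetical helper count_train_steps, assuming the 8 TPU cores listed earlier and a hypothetical training set of 10,000 grouped samples (the real number depends on the data). It deliberately uses new names so that it does not overwrite the notebook's num_train_steps.

```python
def count_train_steps(num_train_samples, per_device_batch_size, num_devices, num_epochs):
    """Return (total_batch_size, steps_per_epoch, total_steps) as in the cell above."""
    total_batch_size = per_device_batch_size * num_devices
    steps_per_epoch = num_train_samples // total_batch_size  # incomplete last batch is dropped
    return total_batch_size, steps_per_epoch, steps_per_epoch * num_epochs


example = count_train_steps(num_train_samples=10_000, per_device_batch_size=16, num_devices=8, num_epochs=10)
print(example)  # prints "(128, 78, 780)"
```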

In the official GPT2 paper a batch size of 512 is used.

Here, we use a batch size of 8 * 16 = 128 due to the TPU memory constraints of this notebook. When running this script locally on a TPUv3-8, one can easily use batch sizes of up to 8 * 64 = 512.

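Now we randomly initialize a distilgpt2 model according to its configuration. To save memory and improve speed, we set dtype=jnp.dtype("bfloat16"), which makes the model perform its computations in bfloat16. Note that for 🤗 Transformers Flax models, dtype only controls the computation dtype: the parameters themselves are still initialized in float32 unless they are cast explicitly (e.g. with model.to_bf16).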

from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))
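bfloat16 keeps float32's sign bit and 8 exponent bits but only 7 mantissa bits, so it covers the same numeric range as float32 with much lower precision. The sketch below illustrates this in plain Python by keeping the top 16 bits of a float32 bit pattern. The helper names are hypothetical, and truncation is a simplification: real converters (including those used on TPUs) typically round to nearest even instead.

```python
import math
import struct


def float32_bits(x):
    """Bit pattern of x after rounding it to IEEE float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bfloat16_truncate(x):
    """Keep sign, 8 exponent bits and 7 mantissa bits (the top 16 bits of float32).

    Truncation instead of round-to-nearest-even keeps the sketch simple.
    """
    bits = float32_bits(x) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


pi_bf16 = bfloat16_truncate(3.14159265)  # only about 2-3 significant decimal digits survive
big_bf16 = bfloat16_truncate(1e38)       # still finite: same exponent range as float32

print(pi_bf16)  # prints "3.140625"
print(math.isfinite(big_bf16))  # prints "True"
```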

Next, we define the learning rate schedule. A simple and effective learning rate schedule is the linear decay with warmup (click here for more information). For simplicity, we set the number of warmup steps simply to 0 here. The schedule is then fully defined by the number of training steps and the learning rate.

It is recommended to use the optax library for training utilities, e.g. learning rate schedules and optimizers.

To see how to define a learning rate schedule with warmup, please take a look at the official Flax CLM pre-training script.

linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
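For reference, a linear schedule clips the step to [0, transition_steps] and interpolates between the start and end values, which is what optax.linear_schedule computes. The plain-Python sketch below mirrors that formula and adds a hypothetical warmup_then_linear_decay helper for the warmup variant mentioned above; in optax itself, such piecewise schedules are usually built by combining schedules, e.g. with optax.join_schedules. The helper names and the warmup values in the usage line are assumptions for illustration only.

```python
def linear_schedule(step, init_value, end_value, transition_steps):
    """Linearly interpolate from init_value to end_value, then hold end_value."""
    frac = min(max(step, 0), transition_steps) / transition_steps
    return init_value + (end_value - init_value) * frac


def warmup_then_linear_decay(step, peak_lr, num_warmup_steps, num_train_steps):
    """Hypothetical helper: ramp 0 -> peak_lr during warmup, then decay peak_lr -> 0."""
    if step < num_warmup_steps:
        return linear_schedule(step, 0.0, peak_lr, num_warmup_steps)
    return linear_schedule(step - num_warmup_steps, peak_lr, 0.0, num_train_steps - num_warmup_steps)


# With 0 warmup steps (as in this notebook), the learning rate halves at mid-training.
print(linear_schedule(50, 3e-4, 0.0, 100))
print(warmup_then_linear_decay(10, 3e-4, num_warmup_steps=10, num_train_steps=110))  # peak: 0.0003
```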

We will be using the standard Adam optimizer with weight decay, called AdamW (Adam + weight decay).

AdamW can easily be imported from optax and is created from the learning rate schedule we just defined as well as a couple of other hyper-parameters (beta1, beta2, epsilon and the weight decay) that are hard-coded in this notebook.

For more information on AdamW (Adam + weight decay), one can take a look at this blog post.

adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)
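The following sketch spells out one AdamW update for a single scalar parameter, using the same hyper-parameters as above (b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01). It assumes the standard formulation with bias correction and decoupled weight decay (the decay term is added to the Adam direction rather than to the gradient); it illustrates the math and is not optax's internal code, which operates on whole parameter trees.

```python
import math


def adamw_step(param, grad, m, v, step, lr, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; `step` counts from 1."""
    m = b1 * m + (1 - b1) * grad          # running mean of gradients
    v = b2 * v + (1 - b2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - b1 ** step)          # bias correction for the zero initialization
    v_hat = v / (1 - b2 ** step)
    # Decoupled weight decay: added to the Adam direction, not folded into the gradient.
    update = m_hat / (math.sqrt(v_hat) + eps) + weight_decay * param
    return param - lr * update, m, v


new_param, m1, v1 = adamw_step(param=1.0, grad=0.5, m=0.0, v=0.0, step=1, lr=3e-4)
print(new_param)  # roughly 1 - 3e-4 * (1 + 0.01)
```

On the first step the bias-corrected Adam direction is close to the sign of the gradient, so the parameter moves by about lr, plus the small decay term lr * weight_decay * param.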

Next, we will create the training state, which bundles the model parameters with the optimizer and is responsible for updating the model's parameters during training.

Most JAX transformations (notably jax.jit) require the functions they transform to be free of side effects. This is because any such side effects will only be executed once, when the Python function is traced during compilation (see Stateful Computations in JAX). As a consequence, Flax models (which can be transformed by JAX transformations) are immutable, and the state of the model (i.e., its weight parameters) is stored outside of the model instance.

Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.

Flax provides a convenience class, flax.training.train_state.TrainState, which stores the training step, the model's apply function, the model parameters, the optimizer and its state, and exposes an apply_gradients method to update the model's weight parameters.

Alright, let's create our training state. We instantiate a TrainState that stores the model's forward pass as apply_fn, the params, and the AdamW optimizer.

state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
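The functional update pattern can be illustrated without JAX. The sketch below is a hypothetical, heavily simplified analogue of TrainState (the name ToyTrainState is made up): it is a frozen dataclass whose apply_gradients returns a new state with an incremented step and updated parameters, leaving the old state untouched. To stay minimal it applies plain SGD, whereas Flax's TrainState delegates the update to the optax transformation stored in tx.

```python
from dataclasses import dataclass, replace
from typing import Dict


@dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical minimal analogue of flax's TrainState (plain SGD instead of an optax tx)."""

    step: int
    params: Dict[str, float]
    learning_rate: float

    def apply_gradients(self, grads: Dict[str, float]) -> "ToyTrainState":
        # Build new parameters instead of mutating the existing ones.
        new_params = {k: p - self.learning_rate * grads[k] for k, p in self.params.items()}
        return replace(self, step=self.step + 1, params=new_params)


toy_state = ToyTrainState(step=0, params={"w": 1.0}, learning_rate=0.1)
toy_state_next = toy_state.apply_gradients({"w": 2.0})

print(toy_state.step, toy_state.params)  # the original state is unchanged
print(toy_state_next.step, toy_state_next.params)
```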

Next, let's implement a data loader for both training and evaluation. The data loader can be defined as a Python generator that yields one batch of model inputs at a time.

First, a random permutation of the whole dataset is drawn (for evaluation, the indices are simply kept in order). Then, each time the data loader is advanced, the next batch of the (shuffled) dataset is extracted, converted to JAX arrays and sharded over all local TPU devices.

def data_loader(rng, dataset, batch_size, shuffle=False):
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

        batch = shard(batch)

        yield batch

At each training epoch, the dataset should be shuffled, and the superfluous samples that keep the dataset from being evenly divisible by the batch size are thrown away. Instead of passing the dataset itself, we prepare the indices of the data samples to be used in each epoch, for both training and evaluation. The indices for the training dataset are additionally shuffled randomly before each epoch.
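The index bookkeeping of data_loader can be mimicked in plain Python, as in the sketch below. It swaps in random.Random for jax.random.permutation, and the hypothetical shard_list stands in for flax's shard, which reshapes the leading batch axis into (number of local devices, per-device batch). The sizes (1,000 samples, global batch of 128, 8 devices) are assumptions for illustration, and the global batch size is assumed to be divisible by the number of devices.

```python
import random


def batch_indices(num_samples, batch_size, seed=None):
    """Yield lists of sample indices, optionally shuffled, dropping the incomplete last batch."""
    indices = list(range(num_samples))
    if seed is not None:
        random.Random(seed).shuffle(indices)
    steps_per_epoch = num_samples // batch_size
    for s in range(steps_per_epoch):
        yield indices[s * batch_size : (s + 1) * batch_size]


def shard_list(batch, num_devices):
    """Split one global batch into num_devices equal per-device chunks."""
    per_device = len(batch) // num_devices
    return [batch[d * per_device : (d + 1) * per_device] for d in range(num_devices)]


batches = list(batch_indices(num_samples=1_000, batch_size=128, seed=0))
first_shards = shard_list(batches[0], num_devices=8)

print(len(batches), [len(s) for s in first_shards])  # 7 batches; 8 shards of 16 indices each
```

With 1,000 samples, 7 full batches (896 samples) are used per epoch and the remaining 104 samples are skipped; because the permutation changes every epoch, different samples are skipped each time.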

During pre-training, we want to update the model parameters and evaluate the performance after each epoch.

Let's write the functions train_step and eval_step accordingly. During training the weight parameters should be updated as follows:

  1. Define a loss function loss_fn that first runs a forward pass of the model on the input data. Remember that Flax models are immutable, so we explicitly pass them the state (in this case the model parameters and the dropout RNG). loss_fn returns a scalar loss, computed with optax.softmax_cross_entropy, between the model's logits and the target labels.

  2. Differentiate this loss function using jax.value_and_grad. This is a JAX transformation for automatic differentiation, which computes the gradient of loss_fn with respect to its input (i.e., the parameters of the model) and returns the value and the gradient as a pair (loss, gradients).

  3. Compute the mean gradient over all devices using the collective operation jax.lax.pmean. As we will see below, each device runs train_step on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.

  4. Use state.apply_gradients, which applies the gradients to the weights.

Below, you can see how each of the described steps above is put into practice.

Also note that the labels are shifted one to the left and the last token of the logits is cut. This way, the model learns to predict the next token as defined in causal language modeling.

def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        # The logit at position t predicts token t + 1: drop the last logit and the first label.
        loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
    )

    return new_state, metrics, new_dropout_rng
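The shifted loss in loss_fn can be written out for a single unbatched sequence in plain Python, as in the sketch below. It mirrors optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], vocab_size)).mean(): the logits at position t are scored against the token at position t + 1, using a numerically stable log-sum-exp. The helper name next_token_loss is hypothetical. No attention-mask weighting is applied, which matches this notebook because group_texts produces full, unpadded chunks.

```python
import math


def next_token_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: per-position score lists (seq_len x vocab_size)
    labels: token ids (seq_len)
    """
    losses = []
    for scores, target in zip(logits[:-1], labels[1:]):  # drop the last logit and the first label
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # stable log-sum-exp
        losses.append(log_z - scores[target])  # equals -log softmax(scores)[target]
    return sum(losses) / len(losses)


# Uniform predictions over a vocabulary of 4 tokens give a loss of log(4).
print(next_token_loss([[0.0] * 4] * 3, [0, 1, 2]))
# Confident, correct predictions of the next token give a loss close to 0.
print(next_token_loss([[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [0.0, 0.0, 0.0]], [0, 1, 2]))
```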

Now, we want to train in parallel over all TPU devices. To do so, we use jax.pmap. It compiles the function once and runs the same program on each device (an SPMD program). When calling this pmapped function, the first axis of every input is mapped over the TPU devices: the state must therefore be replicated on every device, the batch must be sharded across the devices, and one dropout_rng must be provided per device.

parallel_train_step = jax.pmap(train_step, "batch")
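Why pmean keeps the replicas in sync can be checked with a toy simulation in plain Python. In the sketch below, each simulated "device" computes the mean gradient of a squared loss on its own equal-sized shard of the batch, and a hypothetical pmean stand-in gives every device the average. The averaged gradient equals the full-batch gradient (because the shards are the same size, as guaranteed by the data loader and shard), so every replica applies the same update. The linear model y = w * x and all numbers are assumptions for illustration.

```python
def grad_squared_loss(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w on one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def pmean(values):
    """Stand-in for jax.lax.pmean: every device receives the mean over all devices."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # data generated with the true weight w = 2
num_devices, w, lr = 4, 0.0, 0.01
per_device = len(xs) // num_devices

# Each "device" computes a gradient on its own shard of the global batch ...
local_grads = [
    grad_squared_loss(w, xs[d * per_device : (d + 1) * per_device], ys[d * per_device : (d + 1) * per_device])
    for d in range(num_devices)
]
# ... and pmean makes the gradients, and hence the updated weights, identical everywhere.
synced_grads = pmean(local_grads)
replica_weights = [w - lr * g for g in synced_grads]

print(local_grads)      # differ per device
print(replica_weights)  # identical on every device
```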

Similarly, we can now define the evaluation step. This function is much simpler, as we don't need to compute any gradients. To better monitor the performance improvement during training, the next-token loss and the corresponding perplexity are computed and stored in a metrics dictionary during evaluation.

def eval_step(params, batch):
    labels = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()

    # summarize metrics
    metrics = {"loss": loss, "perplexity": jnp.exp(loss)}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics

Similarly, we also apply jax.pmap to the evaluation step.

parallel_eval_step = jax.pmap(eval_step, "batch")
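Note that eval_step exponentiates each device's loss before any averaging, so the reported perplexity is an average of per-device, per-batch perplexities rather than exp of the average loss. By Jensen's inequality it is at least exp(mean loss), which is visible in the final log line of this run: exp(5.664) ≈ 288, whereas the logged perplexity is 305.6. The sketch below shows the effect with hypothetical per-batch losses (not taken from this run).

```python
import math

# Hypothetical per-batch evaluation losses (not taken from the notebook's run).
batch_losses = [5.2, 5.6, 6.1]

mean_loss = sum(batch_losses) / len(batch_losses)
ppl_of_mean_loss = math.exp(mean_loss)  # perplexity of the average loss
mean_of_ppl = sum(math.exp(x) for x in batch_losses) / len(batch_losses)  # average of per-batch perplexities

print(ppl_of_mean_loss, mean_of_ppl)  # the second value is larger
```

If a single perplexity for the whole validation set is wanted, exponentiating the averaged loss is the more common convention.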

Next, we replicate the training state (including the weight parameters and the optimizer state) on each device, so that we can pass it to our parallelized mapped functions.

state = flax.jax_utils.replicate(state)
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:317: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  "jax.host_count has been renamed to jax.process_count. This alias "
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:304: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  "jax.host_id has been renamed to jax.process_index. This alias "

We can almost start training! In a final preparation step, we generate a seeded PRNGKey used as the random seed for dropout layers and dataset shuffling.

Similar to how we had to copy/replicate the state on all 8 TPU devices, we also need to generate one PRNGKey per device, which is why we split the initial rng key into 8 random seeds.

rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

Now, we are all set to finally start training! Let's put all the pieces together and write the training loop.

We start each epoch by generating a new random key that will be used to shuffle the training dataset; the dropout layers use the per-device dropout_rngs, which are updated by every training step.

Next, we create the training data loader, which generates the shuffled training dataset indices. In the first nested loop (the training loop), each input batch is sharded over all 8 TPU devices and the training step is run.

Analogously, in the second nested loop (the evaluation loop), the evaluation batches are sharded and the evaluation step is run.

Note: It might seem that the following cell "hangs" when executed for the first time. This is because JAX traces and compiles the code the first time it is run. After the first training step, you should notice that execution is much faster.

for epoch in tqdm(range(1, num_epochs + 1), desc=f"Epoch ...", position=0, leave=True):
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_loader = data_loader(input_rng, tokenized_datasets["train"], total_batch_size, shuffle=True)
    with tqdm(total=len(tokenized_datasets["train"]) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for model_inputs in train_loader:
            # Model forward
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)

            progress_bar_train.update(1)

        progress_bar_train.write(
            f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )

    # -- Eval --
    eval_loader = data_loader(input_rng, tokenized_datasets["validation"], total_batch_size)
    eval_metrics = []

    with tqdm(total=len(tokenized_datasets["validation"]) // total_batch_size, desc="Evaluation...", leave=False) as progress_bar_eval:
        for model_inputs in eval_loader:
            # Model forward
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)

            progress_bar_eval.update(1)

        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})"
        )
Train... (1/10 | Loss: 6.935000419616699, Learning Rate: 0.0002699999895412475)
Eval... (1/10 | Loss: 7.108445644378662 | Perplexity: 1246.529052734375)
Train... (2/10 | Loss: 6.334000110626221, Learning Rate: 0.00023999999393709004)
Eval... (2/10 | Loss: 6.567610740661621 | Perplexity: 738.8753662109375)
Train... (3/10 | Loss: 5.798000335693359, Learning Rate: 0.0002099999983329326)
Eval... (3/10 | Loss: 6.278167247772217 | Perplexity: 557.9488525390625)
Train... (4/10 | Loss: 5.557000160217285, Learning Rate: 0.00018000000272877514)
Eval... (4/10 | Loss: 6.062875270843506 | Perplexity: 451.3289794921875)
Train... (5/10 | Loss: 5.543000221252441, Learning Rate: 0.00014999999257270247)
Eval... (5/10 | Loss: 5.920379161834717 | Perplexity: 392.97332763671875)
Train... (6/10 | Loss: 5.361000061035156, Learning Rate: 0.00011999999696854502)
Eval... (6/10 | Loss: 5.821027755737305 | Perplexity: 356.4353942871094)
Train... (7/10 | Loss: 5.207000255584717, Learning Rate: 9.000000136438757e-05)
Eval... (7/10 | Loss: 5.748736381530762 | Perplexity: 332.1453857421875)
Train... (8/10 | Loss: 5.124000072479248, Learning Rate: 5.999999848427251e-05)
Eval... (8/10 | Loss: 5.703180313110352 | Perplexity: 317.5106201171875)
Train... (9/10 | Loss: 5.220000267028809, Learning Rate: 2.9999999242136255e-05)
Eval... (9/10 | Loss: 5.674434185028076 | Perplexity: 308.7478942871094)
Train... (10/10 | Loss: 4.992000102996826, Learning Rate: 0.0)
Eval... (10/10 | Loss: 5.66389274597168 | Perplexity: 305.58953857421875)

In this Colab, training already reaches a speed of 2.42 training steps per second. Executing run_clm_flax.py on a TPUv3-8 VM should be as fast as 7 training steps per second.

For a more detailed comparison of runtimes, please refer to this table.