GitHub Repository: huggingface/notebooks
Path: blob/main/examples/masked_language_modeling_flax.ipynb
Kernel: Python 3 (ipykernel)


Pre-Training a 🤗 Transformers model on TPU with Flax/JAX

In this notebook, we will see how to pretrain one of the 🤗 Transformers models on TPU using Flax.

The popular masked language modeling (MLM) objective (cf. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding) will be used as the pre-training objective.

As can be seen in this benchmark, using Flax/JAX on GPU/TPU is often much faster and can also be considerably cheaper than using PyTorch on GPU/TPU.

Flax is a high-performance neural network library designed for flexibility built on top of JAX (see below). It aims to provide users with full control of their training code and is carefully designed to work well with JAX transformations such as grad and pmap (see the Flax philosophy). For an introduction to Flax see the Flax Basic Colab or the list of curated Flax examples.

JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more. A great place for getting started with JAX is the JAX 101 Tutorial.

If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers, 🤗 Datasets, 🤗 Tokenizers as well as Flax and Optax. Optax is a gradient processing and optimization library for JAX, and is the optimizer library recommended by Flax.

%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install tokenizers
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

You also will need to set up the TPU for JAX in this notebook. This can be done by executing the following lines.

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

If everything is set up correctly, the following command should return a list of 8 TPU devices.

jax.local_devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In this notebook, we will pre-train an autoencoding model on one of the languages of the OSCAR corpus. OSCAR is a huge multilingual corpus obtained by language classification and filtering of the Common Crawl corpus using the goclassy architecture.

Let's first select the language that our model should learn. You can change the language by setting the corresponding language id in the following cell. The language ids can be found under the "deduplicated" column on the official OSCAR website.

Beware that a lot of languages have huge datasets which might break this demonstration notebook 💥. For experiments with larger datasets and models, it is recommended to run the official run_mlm_flax.py script offline that can be found here.

Here we select "is" for Icelandic 🇮🇸.

language = "is"

Next, we select the model architecture to be trained from scratch. Here we choose roberta-base, but essentially any auto-encoding model that is available on the 🤗 hub in JAX/Flax can be used.

model_config = "roberta-base"

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("masked_language_modeling_notebook", framework="flax")

1. Defining the model configuration

To begin with, we create a directory to save all relevant files of our model including the model's configuration file, the tokenizer's JSON file, and the model weights. We call the directory "roberta-base-pretrained-is":

model_dir = model_config + f"-pretrained-{language}"

and create it:

from pathlib import Path

Path(model_dir).mkdir(parents=True, exist_ok=True)

Next, we'll download the model configuration:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_config)

and save it to the directory:

config.save_pretrained(f"{model_dir}")

2. Training a tokenizer from scratch

One has to pre-process the raw text data to a format that is understandable by the model. In NLP, the de-facto standard is to use a tokenizer to pre-process data as explained here.

We can leverage the blazing-fast 🤗 Tokenizer library to train a ByteLevelBPETokenizer from scratch.

Let's import the necessary building blocks from tokenizers and the load_dataset function.

from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from pathlib import Path

We will store our tokenizer files and model files in a directory, called model_dir. We can load our chosen dataset conveniently using the load_dataset function.

raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_{language}")
Downloading and preparing dataset oscar/unshuffled_deduplicated_is (download: 317.45 MiB, generated: 849.77 MiB, post-processed: Unknown size, total: 1.14 GiB) to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d...
Dataset oscar downloaded and prepared to /root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d. Subsequent calls will reuse this data.

Having imported the ByteLevelBPETokenizer, we instantiate it,

tokenizer = ByteLevelBPETokenizer()

define a training iterator,

def batch_iterator(batch_size=1000):
    # iterate over the "train" split; len(raw_dataset) alone would only count the splits
    for i in range(0, len(raw_dataset["train"]), batch_size):
        yield raw_dataset["train"][i: i + batch_size]["text"]
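To make the batching pattern concrete, here is a minimal sketch on a plain Python list. The list `toy_texts` is a hypothetical stand-in for the "text" column of the training split, and the extra `texts` argument is an assumption made so that the example is self-contained.

```python
def batch_iterator(texts, batch_size=1000):
    """Yield consecutive slices of `texts`, each with at most `batch_size` items."""
    for start in range(0, len(texts), batch_size):
        yield texts[start : start + batch_size]


# Hypothetical stand-in for the "text" column of the training split.
toy_texts = [f"sentence {i}" for i in range(7)]

batches = list(batch_iterator(toy_texts, batch_size=3))
print([len(batch) for batch in batches])  # the last batch may be shorter
```

Because the function is a generator, only one batch of text is materialized at a time, which keeps memory usage low while the tokenizer is trained.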

and train the tokenizer by defining vocab_size according to our model's configuration along with the min_frequency as well as some special_tokens:

tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

Finally, we save the trained tokenizer in the model folder.

tokenizer.save(f"{model_dir}/tokenizer.json")

For more information on training tokenizers, see this document.

3. Pre-processing the dataset

The trained tokenizer can now be used to pre-process the raw text data. Most auto-encoding models, such as BERT and RoBERTa, are trained to handle sequences up to 512 tokens. However, natural language understanding (NLU) tasks often require the model to process inputs only up to a length of 128 tokens, cf. How to Train BERT with an Academic Budget.

Since the required memory of Transformer models scales quadratically with the sequence length, we cap the maximum input length at 128 here. The raw text data is pre-processed accordingly.
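As a rough illustration of the quadratic scaling, the sketch below counts the entries of the attention score matrices for one layer and one sample. It assumes 12 attention heads, as in roberta-base, and ignores every other source of memory usage, so it is an order-of-magnitude argument rather than a memory profile.

```python
def attention_score_entries(seq_length, num_heads=12):
    """Entries in the attention score matrices of one layer for one sample."""
    # Each head compares every token with every other token: seq_length ** 2.
    return num_heads * seq_length * seq_length


short = attention_score_entries(128)
long = attention_score_entries(512)

print(short)          # entries at sequence length 128
print(long)           # entries at sequence length 512
print(long // short)  # 4x longer sequences need 16x more entries
```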

max_seq_length = 128

To evaluate the model's performance during pre-training, we hold out 5% of the data as the validation set.

Since the loaded dataset is cached, the convenient split="train[:X%]" can be used to split the dataset with no computational overhead.

The last 95% will be used as the training data:

raw_dataset["train"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[5%:]")
Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)

and the first 5% as the validation data.

raw_dataset["validation"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[:5%]")
Reusing dataset oscar (/root/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_is/1.0.0/e4f06cecc7ae02f7adf85640b4019bf476d44453f251a1d84aebae28b0f8d51d)

For demonstration purposes, we will use only the first 10000 samples of the training data and the first 1000 samples of the validation data to not have to wait too much for each cell to be executed.

If you want to run the colab on the full dataset, please comment out the following cell. Using the full dataset, the notebook will run for ca. 12 hours until loss convergence and give a final accuracy of around 50%. Running the colab as is will take less than 15 minutes, but will not show good loss convergence.

# these cells should be commented out to run on full dataset
raw_dataset["train"] = raw_dataset["train"].select(range(10000))
raw_dataset["validation"] = raw_dataset["validation"].select(range(1000))

Next, we load the previously trained ByteLevelBPETokenizer tokenizer to pre-process the raw text data:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")

We can then write the function that will preprocess the raw text data. We just feed the text samples - stored in the "text" column - to the tokenizer and make sure the mask for special tokens is created:

def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True)

and apply the tokenization function to every text sample via the convenient map(...) function of Datasets. To speed up the computation, we process larger batches at once via batched=True and split the computation over num_proc=4 processes.

Note: Running this command on the whole dataset might take up to 10 minutes ☕.

tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)

Following RoBERTa: A Robustly Optimized BERT Pretraining Approach, our model is pre-trained just with a masked language modeling (MLM) objective which is independent of whether the input sequence ends with a finished or unfinished sentence.

The model can process the training data most efficiently if all data samples are of the same length. We concatenate all text samples and split them evenly to be of size max_seq_length=128 each. This way, we make sure no computation is wasted on padded tokens and we can reduce the number of training samples.

Let's define such a function to group the dataset into equally sized data samples:

def group_texts(examples):
    # concatenate all samples of each column into one long list
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # drop the remainder so that all chunks have exactly max_seq_length tokens
    total_length = (total_length // max_seq_length) * max_seq_length
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result
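To see what the grouping does, here is the same logic applied to a tiny hand-made batch. The toy token ids and the chunk size of 4 are assumptions chosen for readability, and `max_seq_length` is passed as an argument here only to keep the sketch self-contained.

```python
def group_texts(examples, max_seq_length):
    # Concatenate every column (e.g. "input_ids") into one long list.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder so that every chunk has exactly max_seq_length items.
    total_length = (total_length // max_seq_length) * max_seq_length
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated.items()
    }


# Three hypothetical tokenized samples of lengths 3, 2 and 5.
toy_batch = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}

grouped = group_texts(toy_batch, max_seq_length=4)
print(grouped)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}
```

The ten input tokens yield two full chunks; the two leftover tokens are dropped, and chunk boundaries no longer coincide with the original sample boundaries.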

We pass group_texts to the map(...) function and set batched=True to make sure that the function is applied to a large batch of data samples.

Note: Running this function on the whole dataset might take up to 50 minutes 🕒.

tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)

Awesome, the data is now fully pre-processed and ready to be used for training 😎.

4. Pre-Training the model

Now we will see how the power of Google's tensor processing unit (TPU) can be leveraged with Flax/JAX for the compute-intensive pre-training of language models.

We need to import jax, flax, optax, numpy to define our training loop. Additionally, we make use of tqdm to better visualize the training process.

import jax
import optax
import flax
import jax.numpy as jnp

from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import numpy as np

from tqdm.notebook import tqdm

At first, we define all relevant hyper-parameters for pretraining in this notebook:

  • Each TPU device will process a batch size of 64

  • The model is trained for 10 epochs

  • The learning rate starts at 5e-5 and is successively linearly decayed with each training step

  • To reproduce the training run, a random seed is set to 0.

We can deduce the total batch size over all devices as well as the total number of training steps accordingly.

per_device_batch_size = 64
num_epochs = 10
training_seed = 0
learning_rate = 5e-5

total_batch_size = per_device_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs

It has been shown that for MLM pretraining it is more efficient to use much larger batch sizes, though this requires many GPUs or TPUs.

We use a batch size of 8 * 64 = 512 here due to the TPU memory constraints of this notebook. When running this script locally on a TPUv3-8, one can easily use batch sizes of up to 8 * 256 = 2048.
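The arithmetic behind the total batch size and the number of training steps can be checked in isolation. In this sketch the device count of 8 matches the TPU setup above, while `num_train_samples` is a hypothetical value, since the real number of grouped samples depends on the tokenized data.

```python
per_device_batch_size = 64
num_devices = 8             # assumption: the 8 TPU devices listed earlier
num_epochs = 10
num_train_samples = 10_000  # hypothetical number of grouped training samples

total_batch_size = per_device_batch_size * num_devices
steps_per_epoch = num_train_samples // total_batch_size
num_train_steps = steps_per_epoch * num_epochs
# Samples that do not fill a complete batch are skipped in each epoch.
dropped_per_epoch = num_train_samples % total_batch_size

print(total_batch_size, steps_per_epoch, num_train_steps, dropped_per_epoch)
```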

Now we randomly initialize a roberta-base model according to its configuration. To save memory and improve speed, we initialize the weights directly in bfloat16 by setting dtype=jnp.dtype("bfloat16").

from transformers import FlaxAutoModelForMaskedLM

model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))

Next, we define the learning rate schedule. A simple and effective learning rate schedule is the linear decay with warmup (click here for more information). For simplicity, we set the number of warmup steps simply to 0 here. The schedule is then fully defined by the number of training steps and the learning rate.

It is recommended to use the optax library for training utilities, e.g. learning rate schedules and optimizers.

To see how to define a learning rate schedule with warmup, please take a look at the official Flax MLM pre-training script.

linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
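For intuition, a linear decay schedule can be written in a few lines of plain Python. This is a re-implementation sketch and not optax's own code; the function name `linear_decay` and the value of `num_train_steps` are assumptions made for illustration.

```python
def linear_decay(step, init_value, end_value, transition_steps):
    """Interpolate linearly from init_value to end_value over transition_steps."""
    # Clamp the step so the value stays at end_value after the transition.
    progress = min(max(step, 0), transition_steps) / transition_steps
    return init_value + (end_value - init_value) * progress


learning_rate = 5e-5
num_train_steps = 1000  # hypothetical; the notebook derives it from the dataset size

for step in (0, 100, 500, 1000):
    print(step, linear_decay(step, learning_rate, 0.0, num_train_steps))
```

After 10% of the steps the rate has dropped to 90% of its initial value, which is consistent with the learning rate of about 4.5e-05 logged after the first of 10 epochs further below.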

We will be using the standard Adam optimizer with weight decay, called AdamW (Adam + weight decay).

AdamW can easily be imported from optax and is created from the just defined learning rate schedule as well as a couple of other hyper-parameters (beta1, beta2, epsilon) that are hard-coded in this notebook.

For more information on AdamW (Adam + weight decay), one can take a look at this blog post.

adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)

Next, we will create the training state, which holds the model parameters and the optimizer and is responsible for updating the model's parameters during training.

Most JAX transformations (notably jax.jit) require functions that are transformed to have no side effects. This is because any such side-effects will only be executed once when the Python version of the function is run during compilation (see Stateful Computations in JAX). As a consequence, Flax models (which can be transformed by JAX transformations) are immutable, and the state of the model (i.e., its weight parameters) is stored outside of the model instance.

Models are initialized and updated in a purely functional way: you pass the state to the model when calling it, and the model returns the new (possibly modified) state, leaving the model instance itself unchanged.
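The functional update pattern can be illustrated without JAX or Flax. The sketch below uses a frozen dataclass named `ToyState`, a hypothetical stand-in for a training state, and plain gradient descent instead of AdamW; it only shows that an update returns a new object and leaves the old one untouched.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical immutable training state: a step counter and parameters."""
    step: int
    params: tuple


def apply_gradients(state, grads, learning_rate=0.1):
    # Build new parameters instead of modifying the existing ones in place.
    new_params = tuple(p - learning_rate * g for p, g in zip(state.params, grads))
    return replace(state, step=state.step + 1, params=new_params)


old_state = ToyState(step=0, params=(1.0, 2.0))
new_state = apply_gradients(old_state, grads=(0.5, -0.5))

print(old_state)  # unchanged by the update
print(new_state)  # step incremented, parameters moved against the gradient
```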

Flax provides a convenience class flax.training.train_state.TrainState, which stores things such as the model's forward function, the model parameters, and the optimizer together with its state, and exposes an apply_gradients function to update the model's weight parameters.

Alright, let's begin by defining our training state. We create a TrainState instance that stores the model's forward pass as the apply_fn, the params, and the AdamW optimizer.

state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

For masked language model (MLM) pretraining, some of the input tokens are randomly masked, and the objective is to predict the original vocabulary id of the masked word based only on its context. More precisely, for BERT-like MLM pretraining 15% of all input tokens are replaced by a mask token with 80% probability, by another random token with 10% probability, and stay the same with 10% probability.

Let's implement a data collator that, given a training batch, randomly masks some input tokens according to the BERT-like MLM pretraining above. Note that the 85% of tokens that are not selected for MLM pretraining would be trivial for the model to predict, since it has access to the token itself. To make sure the model learns to predict masked tokens instead of simply copying input tokens to output tokens, we indicate that no loss should be computed on the 85% of non-selected tokens by setting their label to -100.

@flax.struct.dataclass
class FlaxDataCollatorForMaskedLanguageModeling:
    mlm_probability: float = 0.15

    def __call__(self, examples, tokenizer, pad_to_multiple_of=16):
        batch = tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=pad_to_multiple_of)

        special_tokens_mask = batch.pop("special_tokens_mask", None)

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask, tokenizer
        )

        return batch

    def mask_tokens(self, inputs, special_tokens_mask, tokenizer):
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        special_tokens_mask = special_tokens_mask.astype("bool")

        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced
        random_words = np.random.randint(tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
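The nested probabilities can be checked with a small simulation in plain Python. The function name `simulate_mlm_masking` and the sample size are assumptions; the sketch mirrors the collator's sampling logic, where the second draw uses probability 0.5 because half of the remaining 20% of selected tokens should be replaced by a random token.

```python
import random


def simulate_mlm_masking(num_tokens, mlm_probability=0.15, seed=0):
    """Count how many tokens fall into each masking category."""
    rng = random.Random(seed)
    counts = {"mask": 0, "random": 0, "unchanged": 0, "not_selected": 0}
    for _ in range(num_tokens):
        if rng.random() >= mlm_probability:
            counts["not_selected"] += 1  # label would be set to -100
        elif rng.random() < 0.8:
            counts["mask"] += 1          # replaced by the mask token
        elif rng.random() < 0.5:
            counts["random"] += 1        # half of the remaining 20%
        else:
            counts["unchanged"] += 1     # kept, but still predicted
    return counts


num_tokens = 200_000
counts = simulate_mlm_masking(num_tokens)
fractions = {name: count / num_tokens for name, count in counts.items()}
print(fractions)
```

The fractions should come out close to 0.15 * 0.8 = 0.12 for mask tokens, 0.015 each for random and unchanged tokens, and 0.85 for tokens that are not selected.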

Having defined the MLM data collator, we can now instantiate one.

data_collator = FlaxDataCollatorForMaskedLanguageModeling(mlm_probability=0.15)

At each training epoch, the dataset should be shuffled, and superfluous samples that make the dataset not evenly divisible by the batch size are thrown away. Instead of passing the dataset, we prepare the indices of the data samples to be used in each epoch. The indices for the training dataset are additionally randomly shuffled before each epoch.

def generate_batch_splits(num_samples, batch_size, rng=None):
    samples_idx = jax.numpy.arange(num_samples)

    # if random seed is provided, then shuffle the dataset
    if rng is not None:
        samples_idx = jax.random.permutation(rng, samples_idx)

    samples_to_remove = num_samples % batch_size

    # throw away incomplete batch
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]

    batch_idx = np.split(samples_idx, num_samples // batch_size)
    return batch_idx

During pre-training, we want to update the model parameters and evaluate the performance after each epoch.
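The same index bookkeeping can be sketched with the standard library alone. The helper name `toy_batch_splits` and its integer `seed` argument are hypothetical; the notebook's function uses a JAX PRNG key instead.

```python
import random


def toy_batch_splits(num_samples, batch_size, seed=None):
    """Return lists of sample indices, one list per complete batch."""
    indices = list(range(num_samples))
    # Shuffle only if a seed is given (training); keep the order otherwise (evaluation).
    if seed is not None:
        random.Random(seed).shuffle(indices)
    num_batches = num_samples // batch_size
    # Indices beyond the last complete batch are dropped.
    return [indices[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)]


train_splits = toy_batch_splits(num_samples=10, batch_size=4, seed=0)
eval_splits = toy_batch_splits(num_samples=10, batch_size=4)

print(train_splits)  # two shuffled batches of four indices each
print(eval_splits)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
```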

Let's write the functions train_step and eval_step accordingly. During training the weight parameters should be updated as follows:

  1. Define a loss function loss_fn that first runs a forward pass of the model given data input. Remember that Flax models are immutable, so we explicitly pass in the state (in this case the model parameters and the dropout RNG). loss_fn returns a scalar loss: the cross-entropy between the model output and the target labels, averaged over the masked tokens.

  2. Differentiate this loss function using jax.value_and_grad. This is a JAX transformation called automatic differentiation, which computes the gradient of loss_function given the input to the function (i.e., the parameters of the model), and returns the value and the gradient in a pair (loss, gradients).

  3. Compute the mean gradient over all devices using the collective operation lax.pmean. As we will see below, each device runs train_step on a different batch of data, but by taking the mean here we ensure the model parameters are the same on all devices.

  4. Use state.apply_gradients, which applies the gradients to the weights.

Below, you can see how each of the described steps above is put into practice.

def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        labels = batch.pop("labels")

        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

        # compute loss, ignore padded input tokens
        label_mask = jax.numpy.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

        # take average
        loss = loss.sum() / label_mask.sum()

        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
    )

    return new_state, metrics, new_dropout_rng

Now, we want to do parallelized training over all TPU devices. To do so, we use jax.pmap. This will compile the function once and run the same program on each device (it is an SPMD program). When calling this pmapped function, all inputs ("state", "batch", "dropout_rng") must have a leading axis whose size equals the number of devices, because the first axis of each argument is used to map over all TPU devices. The state is therefore replicated, while the batch is sharded across devices.

parallel_train_step = jax.pmap(train_step, "batch")
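The role of the mean over devices can be imitated in plain Python. The helper `toy_pmean` is a hypothetical stand-in and not a JAX function, the gradients are made-up scalars, and the update is plain gradient descent; the sketch only shows why averaging keeps the replicated parameters identical.

```python
def toy_pmean(per_device_values):
    """Return the cross-device mean, repeated once per device."""
    mean = sum(per_device_values) / len(per_device_values)
    return [mean] * len(per_device_values)


# Hypothetical scalar gradients computed on four devices from different batches.
per_device_grads = [0.2, 0.4, 0.6, 0.8]
shared_param = 1.0  # the same replicated parameter on every device
learning_rate = 0.1

averaged_grads = toy_pmean(per_device_grads)
new_params = [shared_param - learning_rate * g for g in averaged_grads]

print(new_params)  # every device ends up with the same updated parameter
```

Without the averaging step, each device would apply its own gradient and the replicas would drift apart after the first update.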

Similarly, we can now define the evaluation step. Here, the function is much easier as we don't need to compute any gradients. To better monitor the performance improvement during training, the accuracy is computed alongside the loss and stored in a metric dictionary during evaluation.

def eval_step(params, batch):
    labels = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    label_mask = jax.numpy.where(labels > 0, 1.0, 0.0)
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

    # compute accuracy
    accuracy = jax.numpy.equal(jax.numpy.argmax(logits, axis=-1), labels) * label_mask

    # summarize metrics
    metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
    metrics = jax.lax.psum(metrics, axis_name="batch")

    return metrics

Similarly, we also apply jax.pmap to the evaluation step.

parallel_eval_step = jax.pmap(eval_step, "batch")

Next, we replicate/copy the weight parameters on each device, so that we can pass them to our parallelized mapped functions.

state = flax.jax_utils.replicate(state)
/usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:317: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code. "jax.host_count has been renamed to jax.process_count. This alias " /usr/local/lib/python3.7/dist-packages/jax/lib/xla_bridge.py:304: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code. "jax.host_id has been renamed to jax.process_index. This alias "

To monitor the performance during training, we accumulate the loss and the accuracy of each evaluation step. Because the loss is not computed on most input tokens, we need to normalize the accuracy and loss before computing the average.

Let's wrap this logic into a process_eval_metrics function to not clutter the training loop too much.

def process_eval_metrics(metrics):
    metrics = get_metrics(metrics)
    # jax.tree_util.tree_map is the stable location of tree_map across JAX versions
    metrics = jax.tree_util.tree_map(jax.numpy.sum, metrics)
    normalizer = metrics.pop("normalizer")
    metrics = jax.tree_util.tree_map(lambda x: x / normalizer, metrics)
    return metrics
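The reason for carrying a normalizer can be seen on made-up numbers. In this sketch `toy_process_eval_metrics` is a hypothetical plain-Python counterpart of the function above, and the two metric dictionaries are invented evaluation batches with different numbers of masked tokens.

```python
def toy_process_eval_metrics(batch_metrics):
    """Sum all metrics over the batches, then divide by the summed normalizer."""
    totals = {"loss": 0.0, "accuracy": 0.0, "normalizer": 0.0}
    for metrics in batch_metrics:
        for name in totals:
            totals[name] += metrics[name]
    normalizer = totals.pop("normalizer")
    return {name: value / normalizer for name, value in totals.items()}


# Hypothetical summed metrics; "normalizer" is the number of masked tokens.
eval_batches = [
    {"loss": 40.0, "accuracy": 2.0, "normalizer": 10.0},
    {"loss": 60.0, "accuracy": 18.0, "normalizer": 30.0},
]

result = toy_process_eval_metrics(eval_batches)
print(result)  # {'loss': 2.5, 'accuracy': 0.5}
```

Averaging the per-batch means instead would give a loss of 3.0 and an accuracy of 0.4, because the batch with fewer masked tokens would be weighted too heavily.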

We can almost start training! In a final preparation step, we generate a seeded PRNGKey used as the random seed for dropout layers and dataset shuffling.

Similar to how we had to copy/replicate the state on all 8 TPU devices, we also need to generate one PRNGKey per device, which is why we split the initial rng key into 8 random seeds.

rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

Now, we are all set to finally start training! Let's put all the pieces together and write the training loop.

We start each epoch by generating a new random seed that will be used for dataset shuffling, the dropout layers and the input token masking.

Next, we generate the training dataset indices. In the first nested loop - the training loop - we shard the input batch on all 8 TPU devices, and run the training step.

Analogously, in the second nested loop - the evaluation loop - the evaluation batches are sharded and the evaluation step is run.

Note: It might seem that the following cell "hangs" when executed for the first time. This is because JAX first traces and compiles the code the very first time it is run. After the first training step, you should notice that execution is much faster.

for epoch in tqdm(range(1, num_epochs + 1), desc=f"Epoch ...", position=0, leave=True):
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_batch_idx = generate_batch_splits(len(tokenized_datasets["train"]), total_batch_size, rng=input_rng)

    with tqdm(total=len(train_batch_idx), desc="Training...", leave=False) as progress_bar_train:
        for batch_idx in train_batch_idx:
            model_inputs = data_collator(tokenized_datasets["train"][batch_idx], tokenizer=tokenizer, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)

            progress_bar_train.update(1)

        progress_bar_train.write(
            f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )

    # -- Eval --
    eval_batch_idx = generate_batch_splits(len(tokenized_datasets["validation"]), total_batch_size)
    eval_metrics = []

    with tqdm(total=len(eval_batch_idx), desc="Evaluation...", leave=False) as progress_bar_eval:
        for batch_idx in eval_batch_idx:
            model_inputs = data_collator(tokenized_datasets["validation"][batch_idx], tokenizer=tokenizer)

            # Model forward
            model_inputs = shard(model_inputs.data)
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)

            progress_bar_eval.update(1)

        eval_metrics_dict = process_eval_metrics(eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics_dict['loss']}, Acc: {eval_metrics_dict['accuracy']})"
        )
Train... (1/10 | Loss: 8.718000411987305, Learning Rate: 4.5000000682193786e-05)
Eval... (1/10 | Loss: 8.744632720947266, Acc: 0.048040375113487244)

It can be seen that in this colab training already reaches a speed of 1.26 training steps per second. Executing run_mlm_flax.py on a TPUv3-8 VM should be as fast as 6 training steps per second.

For a more detailed comparison of runtimes, please refer to this table.