How to benchmark models with Transformers
With ever-larger language models, it is no longer enough to just compare models on their performance on a specific task. One should always be aware of the computational cost that is attached to a specific model. For a given computation environment (e.g. type of GPU), the computational cost of training a model or deploying it in inference usually depends only on the required memory and the required time.
Being able to accurately benchmark language models on both speed and required memory is therefore very important.
Hugging Face's Transformers library allows users to benchmark models for both TensorFlow 2 and PyTorch using the `PyTorchBenchmark` and `TensorFlowBenchmark` classes.
The currently available features for `PyTorchBenchmark` are summarized in the following table.
|                    | CPU | CPU + torchscript | GPU | GPU + torchscript | GPU + FP16 | TPU |
|---|---|---|---|---|---|---|
| Speed - Inference  | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
| Memory - Inference | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ |
| Speed - Train      | ✔ | ✘ | ✔ | ✘ | ✔ | ✔ |
| Memory - Train     | ✔ | ✘ | ✔ | ✘ | ✔ | ✘ |
FP16 stands for mixed precision, meaning that computations within the model are done using a mixture of 16-bit and 32-bit floating-point operations; see here for more detail.

torchscript corresponds to PyTorch's TorchScript format; see here.
The currently available features for `TensorFlowBenchmark` are summarized in the following table.
|                    | CPU | CPU + eager execution | GPU | GPU + eager execution | GPU + XLA | GPU + FP16 | TPU |
|---|---|---|---|---|---|---|---|
| Speed - Inference  | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✔ |
| Memory - Inference | ✔ | ✔ | ✔ | ✔ | ✔ | ✘ | ✘ |
| Speed - Train      | ✔ | ✘ | ✔ | ✘ | ✘ | ✘ | ✔ |
| Memory - Train     | ✔ | ✘ | ✔ | ✘ | ✘ | ✘ | ✘ |
eager execution means that the function is run in the eager execution environment of TensorFlow 2; see here.

XLA stands for TensorFlow's Accelerated Linear Algebra (XLA) compiler; see here.

FP16 stands for TensorFlow's mixed-precision package and is analogous to PyTorch's FP16 feature; see here.
Note: Benchmark training in TensorFlow is not included in v3.0.2, but is available on master.
This notebook will show the user how to use `PyTorchBenchmark` and `TensorFlowBenchmark` for two different scenarios:

1. Inference - Pre-trained Model Comparison - A user wants to deploy a pre-trained model in production for inference. She wants to compare different models on speed and required memory.

2. Training - Configuration Comparison - A user wants to train a specific model and searches for the most effective model configuration for this purpose.
Inference - Pre-trained Model Comparison
Let's say we want to employ a question-answering model in production. The questions are expected to be of the same format as in SQuAD v2, so the model we choose should have been fine-tuned on this dataset. Hugging Face's new dataset webpage lets the user see all relevant information about a dataset and even links to the models that have been fine-tuned on this specific dataset. Let's check out the dataset webpage of SQuAD v2 here.
Nice, we can see that there are 7 available models.

Let's assume that we have decided to restrict our pipeline to "encoder-only" models, so that we are left with:

- `a-ware/roberta-large-squad-classification`
- `a-ware/xlmroberta-squadv2`
- `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2`
- `deepset/roberta-base-squad2`
- `mrm8488/longformer-base-4096-finetuned-squadv2`
Great! In this notebook, we will now benchmark these models on both peak memory consumption and inference time to decide which model should be employed in production.
Note: None of the models has been evaluated on task performance here, so we will just assume that all models perform more or less equally well. The purpose of this notebook is not to find the best model for SQuAD v2, but to showcase how Transformers' benchmarking tools can be leveraged.

First, we assume to be limited by the available GPU on this Google Colab instance, which in this copy amounts to 16 GB of GPU memory.
In a first step, we will check which models are the most memory-efficient ones. Let's make sure 100% of the GPU is available to us in this notebook.
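A minimal sketch of such a check, assuming the `gputil`, `psutil`, and `humanize` packages; the exact install and print format are illustrative and mirror the output below:

```python
# Sketch of a GPU/RAM availability check in this Colab runtime.
# gputil, psutil and humanize are assumed to be available (installed below).
!pip install -q gputil psutil humanize

import os

import GPUtil as GPU
import humanize
import psutil

gpu = GPU.getGPUs()[0]  # the single GPU assigned to this notebook
process = psutil.Process(os.getpid())

print(
    "Gen RAM Free: " + humanize.naturalsize(psutil.virtual_memory().available),
    "| Proc size: " + humanize.naturalsize(process.memory_info().rss),
)
print(
    "GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:.0f}% | Total {3:.0f}MB".format(
        gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil * 100, gpu.memoryTotal
    )
)
```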
```
Building wheel for gputil (setup.py) ... done
Gen RAM Free: 12.8 GB | Proc size: 160.0 MB
GPU RAM Free: 16280MB | Used: 0MB | Util 0% | Total 16280MB
```
Looks good! Now we import `transformers` and download the scripts `run_benchmark.py`, `run_benchmark_tf.py`, and `plot_csv_file.py`, which can be found under `transformers/examples/benchmarking`.
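One way to fetch them in Colab looks roughly like this; the raw URLs are an assumption based on the repository path above and the master branch:

```python
# Install transformers and download the benchmarking scripts.
# URLs assumed from the transformers/examples/benchmarking layout on master.
!pip install -q transformers
!wget -q https://raw.githubusercontent.com/huggingface/transformers/master/examples/benchmarking/run_benchmark.py
!wget -q https://raw.githubusercontent.com/huggingface/transformers/master/examples/benchmarking/run_benchmark_tf.py
!wget -q https://raw.githubusercontent.com/huggingface/transformers/master/examples/benchmarking/plot_csv_file.py
```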
`run_benchmark.py` and `run_benchmark_tf.py` are very simple scripts leveraging the `PyTorchBenchmark` and `TensorFlowBenchmark` classes, respectively.
We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.
Information about the input arguments to the run_benchmark scripts can be accessed by running `!python run_benchmark.py --help` for PyTorch and `!python run_benchmark_tf.py --help` for TensorFlow.
Great, we are ready to run our first memory benchmark. By default, both the required memory and the inference time are measured. To disable benchmarking on time, we add `--no_speed`.

The only required parameter is `--models`, which expects a list of model identifiers as defined on the model hub. Here we add the five model identifiers listed above.

Next, we define the `sequence_lengths` and `batch_sizes` for which the peak memory is calculated.

Finally, because the results should be stored in a CSV file, the option `--save_to_csv` is added and the path to save the results is passed via the `--inference_memory_csv_file` argument. Whenever a benchmark is run, the environment information, e.g. GPU type, library versions, etc., can be saved using the `--env_info_csv_file` argument.
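Putting this together, the call looks roughly like the following sketch; the concrete sequence lengths and batch sizes are illustrative choices, not prescribed by the script:

```python
# Peak memory benchmark for the five candidate models (PyTorch).
# Sequence lengths and batch sizes below are illustrative.
!mkdir -p plots_pt
!python run_benchmark.py \
    --no_speed \
    --save_to_csv \
    --models a-ware/roberta-large-squad-classification \
             a-ware/xlmroberta-squadv2 \
             aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2 \
             deepset/roberta-base-squad2 \
             mrm8488/longformer-base-4096-finetuned-squadv2 \
    --sequence_lengths 32 128 512 \
    --batch_sizes 32 128 512 \
    --inference_memory_csv_file plots_pt/required_memory.csv \
    --env_info_csv_file plots_pt/env.csv
```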
Under `plots_pt`, two files are now created: `required_memory.csv` and `env.csv`. Let's check out `required_memory.csv` first.
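For instance, with pandas (using the path from above):

```python
import pandas as pd

# Peak memory results: one row per (model, batch size, sequence length) combination.
pd.read_csv("plots_pt/required_memory.csv")
```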
Each row in the CSV file lists one data point showing the peak memory usage for a given model, batch size, and sequence length. As can be seen, some values have a NaN result, meaning that an out-of-memory error occurred. To better visualize the results, one can make use of the `plot_csv_file.py` script.
Before doing so, let's take a look at the information about our computation environment.
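For example:

```python
# Print the saved environment information.
!cat plots_pt/env.csv
```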
We can see all relevant information here: the PyTorch version, the Python version, the system, the type of GPU, the available RAM on the GPU, etc.
Note: A different GPU is likely assigned to a copy of this notebook, so that all of the following results may be different. It is very important to always include the environment information when benchmarking your models for both reproducibility and transparency to other users.
Alright, let's plot the results.
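A sketch of the plotting call; the `--csv_file` and `--figure_png_file` flag names are assumptions about `plot_csv_file.py` (check `!python plot_csv_file.py --help` for the actual arguments):

```python
# Plot peak memory vs. sequence length from the saved CSV.
# Flag names are assumed; verify with `!python plot_csv_file.py --help`.
!python plot_csv_file.py \
    --csv_file plots_pt/required_memory.csv \
    --figure_png_file plots_pt/required_memory_plot.png
```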
At this point, it is important to understand how the peak memory is measured. The benchmarking tools measure the peak memory usage the same way the command `nvidia-smi` does - see here for more information. In short, all memory that is allocated for a given model identifier, batch size, and sequence length is measured in a separate process. This way it can be ensured that there is no previously unreleased memory falsely included in the measurement. One should also note that the measured memory even includes the memory allocated by the CUDA driver to load PyTorch and TensorFlow and is, therefore, higher than what library-specific memory measurement functions report, e.g. this one for PyTorch.
Alright, let's analyze the results. It can be noted that the models `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` and `deepset/roberta-base-squad2` require significantly less memory than the other three models. Apart from `mrm8488/longformer-base-4096-finetuned-squadv2`, all models more or less follow the same memory consumption pattern, with `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` seemingly scaling better to larger sequence lengths. `mrm8488/longformer-base-4096-finetuned-squadv2` is a Longformer model, which makes use of local attention (check out this blog post to learn more about local attention) so that it scales much better to longer input sequences.

For the sake of this notebook, we assume that the longest required input will be less than 512 tokens, so we settle on the models `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` and `deepset/roberta-base-squad2`.
To better understand how many API requests of our question-answering pipeline can be run in parallel, we are interested in finding out at which batch sizes the two models run out of memory.

Let's plot the results again, this time with the batch size on the x-axis.
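A sketch of such a call; `--plot_along_batch` is an assumed flag name for switching the x-axis to the batch size:

```python
# Same data, plotted along the batch size dimension.
# --plot_along_batch is assumed; verify with `!python plot_csv_file.py --help`.
!python plot_csv_file.py \
    --csv_file plots_pt/required_memory.csv \
    --figure_png_file plots_pt/required_memory_plot_batch.png \
    --plot_along_batch
```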
Interesting! `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` clearly scales better for higher batch sizes and does not even run out of memory for 512 tokens.
For comparison, let's run the same benchmarking on TensorFlow.
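A sketch of the analogous TensorFlow call, restricted to the two remaining candidates; the `plots_tf/` output path and the swept values are illustrative:

```python
# Peak memory benchmark for the two remaining models (TensorFlow).
!mkdir -p plots_tf
!python run_benchmark_tf.py \
    --no_speed \
    --save_to_csv \
    --models aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2 \
             deepset/roberta-base-squad2 \
    --sequence_lengths 32 128 512 \
    --batch_sizes 32 128 512 \
    --inference_memory_csv_file plots_tf/required_memory.csv \
    --env_info_csv_file plots_tf/env.csv
```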
Let's see the same plot for TensorFlow.
The model implemented in TensorFlow requires more memory than the one implemented in PyTorch. Let's say for whatever reason we have decided to use TensorFlow instead of PyTorch.
The next step is to measure the inference time of these two models. Instead of disabling time measurement with `--no_speed`, we will now disable memory measurement with `--no_memory`.
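Roughly, the call becomes the following; the `--inference_time_csv_file` argument name is assumed to mirror the memory one above:

```python
# Inference speed benchmark for the two remaining models (TensorFlow).
!python run_benchmark_tf.py \
    --no_memory \
    --save_to_csv \
    --models aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2 \
             deepset/roberta-base-squad2 \
    --sequence_lengths 32 128 512 \
    --batch_sizes 32 128 512 \
    --inference_time_csv_file plots_tf/time.csv \
    --env_info_csv_file plots_tf/env.csv
```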
Ok, this took some time... Time measurements take much longer than memory measurements because the forward pass is called multiple times for stable results. Timing measurements leverage Python's `timeit` module and run the forward pass 10 times the value given to the `--repeat` argument (which defaults to 3), so in our case 30 times.
Let's focus on the resulting plot. It becomes obvious that `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` is around twice as fast as `deepset/roberta-base-squad2`. Given that this model is also more memory efficient and assuming that it performs reasonably well, for the sake of this notebook we settle on `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2`. Our model should be able to process input sequences of up to 512 tokens. A latency of around 2 seconds might be too long though, so let's compare the time for different batch sizes, using TensorFlow's XLA compiler for more speed.
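A sketch of such a run; `--use_xla` is the assumed flag name for enabling XLA, and the batch sizes are illustrative:

```python
# Inference speed with XLA enabled, sweeping batch sizes at 512 tokens.
# --use_xla is assumed; verify with `!python run_benchmark_tf.py --help`.
!python run_benchmark_tf.py \
    --no_memory \
    --save_to_csv \
    --use_xla \
    --models aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2 \
    --sequence_lengths 512 \
    --batch_sizes 8 32 64 128 \
    --inference_time_csv_file plots_tf/time_xla.csv \
    --env_info_csv_file plots_tf/env.csv
```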
First of all, it can be noted that XLA reduces the latency by a factor of ca. 1.3 (which is more than observed for other models by TensorFlow here). A batch size of 64 looks like a good choice; more or less half a second for the forward pass is good enough.
Cool, now it should be straightforward to benchmark your favorite models. All the inference time measurements can also be done using the `run_benchmark.py` script for PyTorch.
Training - Configuration Comparison
Next, we will look at how a model can be benchmarked on different configurations. This is especially helpful when one wants to decide how to most efficiently choose the model's configuration parameters for training. In the following, different configurations of a Bart MNLI model will be compared to each other using `PyTorchBenchmark`.

Training in `PyTorchBenchmark` is defined by running one forward pass to compute the loss, `loss = model(input_ids, labels=labels)[0]`, and one backward pass to compute the gradients, `loss.backward()`.
Let's see how to most efficiently train a Bart MNLI model from scratch.

For the sake of the notebook, we assume that we are looking for a more efficient version of Facebook's `bart-large-mnli` model. Let's load its configuration and check out the important parameters.
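Loading the configuration is a one-liner:

```python
from transformers import BartConfig

# Load the configuration of the pre-trained bart-large-mnli checkpoint.
config = BartConfig.from_pretrained("facebook/bart-large-mnli")
print(config)
```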
Alright! The important configuration parameters are usually the number of layers (`config.encoder_layers` and `config.decoder_layers`), the model's hidden size (`config.d_model`), the number of attention heads (`config.encoder_attention_heads` and `config.decoder_attention_heads`), and the vocabulary size (`config.vocab_size`).
Let's create 4 configurations different from the baseline and see how they compare in terms of peak memory consumption.
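A sketch of four such variants, matching the model names used below; the concrete values (hidden size 768, 8 attention heads, a vocabulary of 10000, 8 layers) are reduced versions of the baseline chosen for illustration:

```python
from transformers import BartConfig

# Baseline configuration plus four reduced variants.
config_baseline = BartConfig.from_pretrained("facebook/bart-large-mnli")
config_768_hidden = BartConfig.from_pretrained("facebook/bart-large-mnli", d_model=768)
config_8_heads = BartConfig.from_pretrained(
    "facebook/bart-large-mnli", encoder_attention_heads=8, decoder_attention_heads=8
)
config_10000_vocab = BartConfig.from_pretrained("facebook/bart-large-mnli", vocab_size=10000)
config_8_layers = BartConfig.from_pretrained(
    "facebook/bart-large-mnli", encoder_layers=8, decoder_layers=8
)
```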
Cool, now we can benchmark these configs against the baseline config. This time, instead of using the benchmarking script, we will directly use the `PyTorchBenchmark` class. The class expects the argument `args`, which has to be of type `PyTorchBenchmarkArguments`, and optionally a list of configs.

First, we define the `args` and give the different configurations appropriate model names. The model names must be in the same order as the configs that are directly passed to `PyTorchBenchmark`.
If no `configs` are provided to `PyTorchBenchmark`, it is assumed that the model names `["bart-base", "bart-768-hid", "bart-8-head", "bart-10000-voc", "bart-8-lay"]` correspond to official model identifiers and that their corresponding configs are loaded as was shown in the previous section.

It is assumed that the model will be trained in half precision, so we add the option `fp16=True` for the following benchmarks.
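A sketch of the benchmark setup; the argument names besides `fp16` follow the `PyTorchBenchmarkArguments` fields as I understand them for this library version, and the sequence lengths, batch size, and output paths are illustrative:

```python
from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# Benchmark training memory of the baseline and the four variants in fp16.
benchmark_args = PyTorchBenchmarkArguments(
    models=["bart-base", "bart-768-hid", "bart-8-head", "bart-10000-voc", "bart-8-lay"],
    training=True,       # benchmark forward + backward pass
    no_inference=True,   # skip the inference-only benchmark
    no_speed=True,       # measure memory only
    fp16=True,           # train in half precision
    sequence_lengths=[64, 128, 256, 512],
    batch_sizes=[8],
    save_to_csv=True,
    train_memory_csv_file="plots_pt/training_mem_fp16.csv",
    env_info_csv_file="plots_pt/env.csv",
)

benchmark = PyTorchBenchmark(
    configs=[config_baseline, config_768_hidden, config_8_heads, config_10000_vocab, config_8_layers],
    args=benchmark_args,
)
result = benchmark.run()
```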
Nice, let's plot the results again.
As expected, the model of the baseline config requires the most memory.

It is interesting to see that the "bart-8-head" model initially requires more memory than "bart-10000-voc", but becomes clearly more memory efficient than "bart-10000-voc" at an input length of 512. Less surprising is that the "bart-8-lay" model is by far the most memory efficient, when one remembers that during the forward pass every layer has to store its activations for the backward pass.
Alright, given the data above, let's say we narrow our candidates down to only the "bart-8-head" and "bart-8-lay" models.
Let's compare these models again on training time.
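A sketch of the speed benchmark for the two remaining candidates (again, the concrete values are illustrative; `no_multi_process` is discussed right below):

```python
from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# Benchmark training speed of the two remaining variants in fp16.
benchmark_args = PyTorchBenchmarkArguments(
    models=["bart-8-head", "bart-8-lay"],
    training=True,          # benchmark forward + backward pass
    no_inference=True,      # skip the inference-only benchmark
    no_memory=True,         # measure speed only
    fp16=True,
    no_multi_process=True,  # see the note below on why this is disabled here
    sequence_lengths=[32, 128, 512],
    batch_sizes=[8],
    save_to_csv=True,
    train_time_csv_file="plots_pt/training_speed_fp16.csv",
    env_info_csv_file="plots_pt/env.csv",
)

benchmark = PyTorchBenchmark(configs=[config_8_heads, config_8_layers], args=benchmark_args)
result = benchmark.run()
```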
The option `no_multi_process` disables multi-processing here. This option should in general only be used for testing or debugging. Enabling multi-processing is crucial to ensure accurate memory consumption measurement, but is less important when only measuring speed. The main reason it is disabled here is that Google Colab sometimes raises "CUDA initialization" errors due to the notebook's environment. This problem does not arise when running benchmarks outside of a notebook.
Alright, let's plot the last speed results as well.
Unsurprisingly, "bart-8-lay" is faster than "bart-8-head" by a factor of ca. 1.3. It might very well be that reducing the number of layers by a factor of 2 leads to much more performance degradation than reducing the number of heads by a factor of 2. For more information on computationally efficient Bart models, check out the new distilbart model here.
Alright, that's it! Now you should be able to benchmark your favorite models on your favorite configurations.
Feel free to share your results with the community here or by tweeting us https://twitter.com/HuggingFace 🤗.