GitHub Repository: huggingface/notebooks
Path: blob/main/examples/onnx-export.ipynb
Kernel: Python 3 (ipykernel)

How to export 🤗 Transformers Models to ONNX?

ONNX is an open format for machine learning models. It allows you to save your neural network's computation graph in a framework-agnostic way, which can be particularly helpful when deploying deep learning models.

Indeed, businesses might have other requirements (languages, hardware, ...) for which the training framework is not the best fit at inference time. In that context, having a representation of the actual computation graph that can be shared across the business units and systems of an organization is desirable.

Along with the serialization format, ONNX also provides a runtime library which allows efficient and hardware-specific execution of the ONNX graph. This is done through the onnxruntime project, which already includes collaborations with many hardware vendors to seamlessly deploy models on various platforms.

In this notebook, we'll walk you through the process of converting a PyTorch or TensorFlow transformers model to ONNX and leveraging onnxruntime to run inference tasks on models from 🤗 transformers.

Exporting 🤗 transformers model to ONNX


Exporting models (either PyTorch or TensorFlow) is easily achieved through the conversion tool provided as part of the 🤗 transformers repository.

Under the hood, the process is roughly the following:

  1. Allocate the model from transformers (PyTorch or TensorFlow)

  2. Forward dummy inputs through the model so that ONNX can record the set of operations executed

  3. Optionally define dynamic axes on input and output tensors

  4. Save the graph along with the network parameters
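
To make step 2 more concrete, here is a toy sketch of trace-based recording in plain Python. It is not how PyTorch or ONNX implement tracing; the `Graph`, `Traced` and `model` names are hypothetical and exist only for this illustration. The idea it shows is that forwarding a dummy input through the model lets a wrapper object log every operation it goes through, and that plain Python values met along the way end up stored as constants, which is what a tracer warning about Python values refers to.

```python
# Toy illustration of trace-based recording; not the ONNX or PyTorch implementation.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Graph:
    """Recorded operations, in execution order: (op, left, right, output)."""
    nodes: List[Tuple[str, str, str, str]] = field(default_factory=list)
    counter: int = 0

    def new_name(self) -> str:
        self.counter += 1
        return f"t{self.counter}"


class Traced:
    """A value that records every arithmetic operation applied to it."""

    def __init__(self, name: str, value: float, graph: Graph):
        self.name, self.value, self.graph = name, value, graph

    def _record(self, op: str, other, result: float) -> "Traced":
        # Plain Python numbers are baked into the graph as constants
        other_name = other.name if isinstance(other, Traced) else f"const({other})"
        out = Traced(self.graph.new_name(), result, self.graph)
        self.graph.nodes.append((op, self.name, other_name, out.name))
        return out

    def __add__(self, other):
        other_value = other.value if isinstance(other, Traced) else other
        return self._record("Add", other, self.value + other_value)

    def __mul__(self, other):
        other_value = other.value if isinstance(other, Traced) else other
        return self._record("Mul", other, self.value * other_value)


def model(x, scale):
    # Hypothetical "model": y = x * scale + 1
    return x * scale + 1


graph = Graph()
dummy_x = Traced("x", 2.0, graph)    # dummy input used only to drive the trace
output = model(dummy_x, scale=3.0)   # the forward pass records the operations

for node in graph.nodes:
    print(node)
```

The recorded node list plays the role of the saved graph. Dynamic axes (step 3) are not modelled in this sketch.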

import sys
!{sys.executable} -m pip install --upgrade git+https://github.com/huggingface/transformers
!{sys.executable} -m pip install --upgrade torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
!{sys.executable} -m pip install --upgrade onnxruntime==1.4.0
!{sys.executable} -m pip install -i https://test.pypi.org/simple/ ort-nightly
!{sys.executable} -m pip install --upgrade onnxruntime-tools

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("onnx_export_notebook", framework="pytorch")

!rm -rf onnx/
from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# Handles all the above steps for you
convert(framework="pt", model="bert-base-cased", output=Path("onnx/bert-base-cased.onnx"), opset=11)

# Tensorflow
# convert(framework="tf", model="bert-base-cased", output="onnx/bert-base-cased.onnx", opset=11)
ONNX opset version set to: 11
Loading pipeline (model: bert-base-cased, tokenizer: bert-base-cased)
All model checkpoint weights were used when initializing BertModel.
All the weights of BertModel were initialized from the model checkpoint at bert-base-cased.
/home/mfuntowicz/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/modeling_bert.py:201: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  position_ids = self.position_ids[:, :seq_length]
Creating folder onnx
Using framework PyTorch: 1.6.0
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input token_type_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch', 1: 'sequence'}
Found output output_1 with shape: {0: 'batch'}
Ensuring inputs are in correct order
position_ids is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask', 'token_type_ids']
/home/mfuntowicz/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/modeling_utils.py:1570: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  input_tensor.shape == tensor_shape for input_tensor in input_tensors

How to leverage runtime for inference over an ONNX graph


As mentioned in the introduction, ONNX is a serialization format, and many side projects can load the saved graph and run the actual computations from it. Here, we'll focus on the official onnxruntime. The runtime is implemented in C++ for performance reasons and provides APIs/bindings for C++, C, C#, Java and Python.

In this notebook, we will use the Python API to show how to load a serialized ONNX graph and run inference workloads on various backends through onnxruntime.

onnxruntime is available on PyPI:

  • onnxruntime: ONNX + MLAS (Microsoft Linear Algebra Subprograms)

  • onnxruntime-gpu: ONNX + MLAS + CUDA

!pip install transformers onnxruntime-gpu onnx psutil matplotlib

Preparing for an Inference Session


Inference is done using a specific backend definition which turns on hardware-specific optimizations of the graph.

Optimizations are basically of three kinds:

  • Constant Folding: Convert static variables to constants in the graph

  • Dead Code Elimination: Remove nodes never accessed in the graph

  • Operator Fusing: Merge multiple instructions into one (Linear -> ReLU can be fused to be LinearReLU)
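
As a rough illustration of the first two kinds, the sketch below applies constant folding and dead code elimination to a toy graph stored as a Python dictionary. The graph encoding and the `fold_constants` and `eliminate_dead_code` helpers are assumptions made up for this example; they are not ONNX Runtime APIs, and the real passes work on ONNX nodes and tensors.

```python
# Toy graph optimizer; the graph encoding is hypothetical, not the ONNX format.
import operator
from typing import Dict, List, Tuple

Node = Tuple[str, tuple]  # (op, inputs); a Const node stores its value in inputs
OPS = {"Add": operator.add, "Mul": operator.mul}


def fold_constants(graph: Dict[str, Node]) -> Dict[str, Node]:
    """Replace nodes whose inputs are all constants with a single Const node."""
    folded: Dict[str, Node] = {}
    for name, (op, inputs) in graph.items():  # assumes topological order
        if op in OPS and all(folded[i][0] == "Const" for i in inputs):
            values = [folded[i][1][0] for i in inputs]
            folded[name] = ("Const", (OPS[op](*values),))
        else:
            folded[name] = (op, inputs)
    return folded


def eliminate_dead_code(graph: Dict[str, Node], outputs: List[str]) -> Dict[str, Node]:
    """Keep only the nodes that are reachable from the requested outputs."""
    live, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name in live:
            continue
        live.add(name)
        op, inputs = graph[name]
        if op not in ("Const", "Input"):
            stack.extend(inputs)
    return {name: node for name, node in graph.items() if name in live}


graph = {
    "x": ("Input", ()),
    "two": ("Const", (2.0,)),
    "three": ("Const", (3.0,)),
    "six": ("Mul", ("two", "three")),  # foldable: both inputs are constants
    "y": ("Mul", ("x", "six")),
    "unused": ("Add", ("x", "two")),   # never reaches the output "y"
}

optimized = eliminate_dead_code(fold_constants(graph), outputs=["y"])
for name, node in optimized.items():
    print(name, node)
```

After both passes only the input, the folded constant and the final multiplication remain.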

ONNX Runtime automatically applies most optimizations when specific SessionOptions are set.

Note: Some of the latest optimizations that are not yet integrated into ONNX Runtime are available in the optimization script, which tunes models for the best performance.

# # An optional step unless
# # you want to get a model with mixed precision for perf acceleration on newer GPU
# # or you are working with Tensorflow(tf.keras) models or pytorch models other than bert

# !pip install onnxruntime-tools
# from onnxruntime_tools import optimizer

# # Mixed precision conversion for bert-base-cased model converted from Pytorch
# optimized_model = optimizer.optimize_model("bert-base-cased.onnx", model_type='bert', num_heads=12, hidden_size=768)
# optimized_model.convert_model_float32_to_float16()
# optimized_model.save_model_to_file("bert-base-cased.onnx")

# # optimizations for bert-base-cased model converted from Tensorflow(tf.keras)
# optimized_model = optimizer.optimize_model("bert-base-cased.onnx", model_type='bert_keras', num_heads=12, hidden_size=768)
# optimized_model.save_model_to_file("bert-base-cased.onnx")

# optimize transformer-based models with onnxruntime-tools
from onnxruntime_tools import optimizer
from onnxruntime_tools.transformers.onnx_model_bert import BertOptimizationOptions

# disable embedding layer norm optimization for better model size reduction
opt_options = BertOptimizationOptions('bert')
opt_options.enable_embed_layer_norm = False

opt_model = optimizer.optimize_model(
    'onnx/bert-base-cased.onnx',
    'bert',
    num_heads=12,
    hidden_size=768,
    optimization_options=opt_options)
opt_model.save_model_to_file('bert.opt.onnx')
from os import environ
from psutil import cpu_count

# Constants from the performance optimization available in onnxruntime
# It needs to be done before importing onnxruntime
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
environ["OMP_WAIT_POLICY"] = 'ACTIVE'

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers

from contextlib import contextmanager
from dataclasses import dataclass
from time import time
from tqdm import trange


def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:

    assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"

    # Few properties that might have an impact on performances (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = 1
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Load the model as a graph and prepare the CPU backend
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()

    return session


@contextmanager
def track_infer_time(buffer: [int]):
    start = time()
    yield
    end = time()

    buffer.append(end - start)


@dataclass
class OnnxInferenceResult:
    model_inference_time: [int]
    optimized_model_path: str

Forwarding through our optimized ONNX model running on CPU


When the model is loaded for inference over a specific provider, for instance CPUExecutionProvider as above, an optimized graph can be saved. This graph might include various optimizations, and you might be able to see some higher-level operations in the graph (through Netron for instance) such as:

  • EmbedLayerNormalization

  • Attention

  • FastGeLU

These operations are examples of the kind of optimization onnxruntime performs, here gathering multiple operations into a bigger one (Operator Fusing).
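
The fusing idea can be sketched on a toy list of nodes. This is only an illustration under simplifying assumptions: the `(op, input, output)` node encoding and the `fuse_linear_relu` helper are hypothetical, and the pass only looks at adjacent nodes, whereas a real optimizer matches patterns on the full graph.

```python
# Toy operator fusion on a list of (op, input, output) nodes; not an onnxruntime API.
from typing import Dict, List, Tuple

Node = Tuple[str, str, str]  # (op, input tensor name, output tensor name)


def fuse_linear_relu(nodes: List[Node]) -> List[Node]:
    """Merge a Linear node directly followed by its only consumer, a Relu node."""
    # Count how many nodes read each tensor, so shared intermediates are kept
    consumers: Dict[str, int] = {}
    for _, inp, _ in nodes:
        consumers[inp] = consumers.get(inp, 0) + 1

    fused: List[Node] = []
    i = 0
    while i < len(nodes):
        op, inp, out = nodes[i]
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if (
            op == "Linear"
            and nxt is not None
            and nxt[0] == "Relu"
            and nxt[1] == out
            and consumers.get(out, 0) == 1
        ):
            # One fused node replaces the pair and skips the intermediate tensor
            fused.append(("LinearReLU", inp, nxt[2]))
            i += 2
        else:
            fused.append(nodes[i])
            i += 1
    return fused


nodes = [
    ("Linear", "x", "h1"),
    ("Relu", "h1", "h2"),
    ("Linear", "h2", "h3"),
    ("Tanh", "h3", "y"),
]
fused_nodes = fuse_linear_relu(nodes)
for node in fused_nodes:
    print(node)
```

Fewer nodes means fewer kernel launches and fewer intermediate tensors to write and read back, which is where the speed-up of fused operators comes from.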

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
cpu_model = create_model_for_provider("onnx/bert-base-cased.onnx", "CPUExecutionProvider")

# Inputs are provided through numpy array
model_inputs = tokenizer("My name is Bert", return_tensors="pt")
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}

# Run the model (None = get all the outputs)
sequence, pooled = cpu_model.run(None, inputs_onnx)

# Print information about outputs
print(f"Sequence output: {sequence.shape}, Pooled output: {pooled.shape}")
Sequence output: (1, 6, 768), Pooled output: (1, 768)

Benchmarking PyTorch model

Note: PyTorch model benchmark is run on CPU

from transformers import BertModel

PROVIDERS = {
    ("cpu", "PyTorch CPU"),
    # Uncomment this line to enable GPU benchmarking
    # ("cuda:0", "PyTorch GPU")
}

results = {}

for device, label in PROVIDERS:

    # Move inputs to the correct device
    model_inputs_on_device = {
        arg_name: tensor.to(device)
        for arg_name, tensor in model_inputs.items()
    }

    # Add PyTorch to the providers
    model_pt = BertModel.from_pretrained("bert-base-cased").to(device)
    for _ in trange(10, desc="Warming up"):
        model_pt(**model_inputs_on_device)

    # Compute
    time_buffer = []
    for _ in trange(100, desc=f"Tracking inference time on PyTorch"):
        with track_infer_time(time_buffer):
            model_pt(**model_inputs_on_device)

    # Store the result
    results[label] = OnnxInferenceResult(
        time_buffer,
        None
    )
All model checkpoint weights were used when initializing BertModel.
All the weights of BertModel were initialized from the model checkpoint at bert-base-cased.
Warming up: 100%|██████████| 10/10 [00:00<00:00, 39.30it/s]
Tracking inference time on PyTorch: 100%|██████████| 100/100 [00:02<00:00, 41.09it/s]

Benchmarking PyTorch & ONNX on CPU

Disclaimer: results may vary depending on the hardware used to run the model
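
Because timings are noisy, it can help to report a spread next to the average. The following is a minimal, framework-free sketch of a warm-up plus timing loop; `benchmark` and `fake_inference` are hypothetical names used only here, and `fake_inference` merely stands in for a real call such as `model.run(None, inputs_onnx)`.

```python
# Minimal timing harness; fake_inference is a stand-in for a real model call.
import statistics
from time import perf_counter
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 10, runs: int = 100) -> Dict[str, float]:
    """Time `fn` after a warm-up phase and summarize the timings in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are executed but not measured

    timings_ms = []
    for _ in range(runs):
        start = perf_counter()
        fn()
        timings_ms.append((perf_counter() - start) * 1e3)

    return {
        "mean_ms": statistics.mean(timings_ms),
        "std_ms": statistics.stdev(timings_ms),
        "min_ms": min(timings_ms),
    }


def fake_inference() -> int:
    # Hypothetical workload used instead of a real inference call
    return sum(i * i for i in range(10_000))


stats = benchmark(fake_inference, warmup=5, runs=50)
print(stats)
```

`perf_counter` is used here because it is a monotonic, high-resolution clock, which suits short intervals better than wall-clock time.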

PROVIDERS = {
    ("CPUExecutionProvider", "ONNX CPU"),
    # Uncomment this line to enable GPU benchmarking
    # ("CUDAExecutionProvider", "ONNX GPU")
}

for provider, label in PROVIDERS:
    # Create the model with the specified provider
    model = create_model_for_provider("onnx/bert-base-cased.onnx", provider)

    # Keep track of the inference time
    time_buffer = []

    # Warm up the model
    model.run(None, inputs_onnx)

    # Compute
    for _ in trange(100, desc=f"Tracking inference time on {provider}"):
        with track_infer_time(time_buffer):
            model.run(None, inputs_onnx)

    # Store the result
    results[label] = OnnxInferenceResult(
        time_buffer,
        model.get_session_options().optimized_model_filepath
    )
Tracking inference time on CPUExecutionProvider: 100%|██████████| 100/100 [00:01<00:00, 63.62it/s]
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os

# Compute average inference time + std (one value per provider, in ms)
time_results = {k: np.mean(v.model_inference_time) * 1e3 for k, v in results.items()}
time_results_std = [np.std(v.model_inference_time) * 1e3 for v in results.values()]

plt.rcdefaults()
fig, ax = plt.subplots(figsize=(16, 12))
ax.set_ylabel("Avg Inference time (ms)")
ax.set_title("Average inference time (ms) for each provider")
ax.bar(time_results.keys(), time_results.values(), yerr=time_results_std)
plt.show()
[Figure: bar chart of the average inference time (ms) for each provider]

Quantization support from transformers

Quantization enables the use of integer (instead of floating point) arithmetic to run neural network models faster. From a high-level point of view, quantization works by mapping float32 ranges of values to int8 with as little loss in model performance as possible.
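
To give an intuition for that mapping, here is a small sketch of affine (scale and zero-point) quantization of a list of floats to the int8 range. It is a simplified illustration and not what ONNX Runtime or PyTorch do internally; the `quantize` and `dequantize` helpers and the sample weights are assumptions made for this example.

```python
# Simplified affine quantization to int8; an illustration, not a library implementation.
from typing import List, Tuple

QMIN, QMAX = -128, 127  # representable int8 range


def quantize(values: List[float]) -> Tuple[List[int], float, int]:
    """Map floats to int8 using a single scale and zero point for the whole list."""
    lo = min(min(values), 0.0)  # make sure 0.0 is inside the represented range
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # fall back to 1.0 for all-zero inputs
    zero_point = round(QMIN - lo / scale)
    quantized = [max(QMIN, min(QMAX, round(v / scale) + zero_point)) for v in values]
    return quantized, scale, zero_point


def dequantize(quantized: List[int], scale: float, zero_point: int) -> List[float]:
    """Approximate the original floats from their int8 representation."""
    return [(q - zero_point) * scale for q in quantized]


weights = [-1.0, -0.5, 0.0, 0.25, 1.0]  # hypothetical float32 weights
q_weights, scale, zero_point = quantize(weights)
restored = dequantize(q_weights, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print("int8 values:", q_weights)
print("scale:", scale, "zero point:", zero_point)
print("max round-trip error:", max_error)
```

The round-trip error stays within about one quantization step (`scale`), which is the precision given up in exchange for smaller weights and integer arithmetic.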

Hugging Face provides a conversion tool as part of the transformers repository to easily export quantized models to ONNX Runtime. For more information, please refer to the following:

With this method, the accuracy of the model remains at the same level as the full-precision model. If you want to see benchmarks on model performance, we recommend reading the ONNX Runtime notebook on the subject.

Benchmarking PyTorch quantized model

import torch

# Quantize
model_pt_quantized = torch.quantization.quantize_dynamic(
    model_pt.to("cpu"), {torch.nn.Linear}, dtype=torch.qint8
)

# Warm up
model_pt_quantized(**model_inputs)

# Benchmark PyTorch quantized model
time_buffer = []
for _ in trange(100):
    with track_infer_time(time_buffer):
        model_pt_quantized(**model_inputs)

results["PyTorch CPU Quantized"] = OnnxInferenceResult(
    time_buffer,
    None
)
100%|██████████| 100/100 [00:01<00:00, 90.15it/s]

Benchmarking ONNX quantized model

from transformers.convert_graph_to_onnx import quantize

# Transformers allow you to easily convert float32 model to quantized int8 with ONNX Runtime
quantized_model_path = quantize(Path("bert.opt.onnx"))

# Then you just have to load through ONNX runtime as you would normally do
quantized_model = create_model_for_provider(quantized_model_path.as_posix(), "CPUExecutionProvider")

# Warm up the overall model to have a fair comparison
outputs = quantized_model.run(None, inputs_onnx)

# Evaluate performances
time_buffer = []
for _ in trange(100, desc=f"Tracking inference time on CPUExecutionProvider with quantized model"):
    with track_infer_time(time_buffer):
        outputs = quantized_model.run(None, inputs_onnx)

# Store the result
results["ONNX CPU Quantized"] = OnnxInferenceResult(
    time_buffer,
    quantized_model_path
)
As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint. This limitation will be removed in the next release of onnxruntime.
Quantized model has been written at bert.onnx: ✔
Tracking inference time on CPUExecutionProvider with quantized model: 100%|██████████| 100/100 [00:00<00:00, 237.49it/s]

Show the inference performance of each provider

%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os

# Compute average inference time + std (one value per provider, in ms)
time_results = {k: np.mean(v.model_inference_time) * 1e3 for k, v in results.items()}
time_results_std = [np.std(v.model_inference_time) * 1e3 for v in results.values()]

plt.rcdefaults()
fig, ax = plt.subplots(figsize=(16, 12))
ax.set_ylabel("Avg Inference time (ms)")
ax.set_title("Average inference time (ms) for each provider")
ax.bar(time_results.keys(), time_results.values(), yerr=time_results_std)
plt.show()
[Figure: bar chart of the average inference time (ms) for each provider, including the quantized models]