Path: blob/main/beginner_source/onnx/onnx_registry_tutorial.py
# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
**Extending the ONNX exporter operator support** ||
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_

Extending the ONNX Exporter Operator Support
============================================

**Authors:** `Ti-Tai Wang <[email protected]>`_, `Justin Chu <[email protected]>`_
"""


###############################################################################
# Overview
# --------
#
# This tutorial describes how you can create ONNX implementations for unsupported PyTorch operators
# or replace existing implementations with your own.
#
# We will cover three scenarios that require extending the ONNX exporter's operator support:
#
# * Overriding the implementation of an existing PyTorch operator
# * Using custom ONNX operators
# * Supporting a custom PyTorch operator
#
# What you will learn:
#
# - How to override or add support for PyTorch operators in ONNX.
# - How to integrate custom ONNX operators for specialized runtimes.
# - How to implement and translate custom PyTorch operators to ONNX.
#
# Prerequisites
# ~~~~~~~~~~~~~
#
# Before starting this tutorial, make sure you have completed the following prerequisites:
#
# * ``torch >= 2.6``
# * The target PyTorch operator
# * Completed the
#   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
#   before proceeding
# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__
#
# Overriding the implementation of an existing PyTorch operator
# -------------------------------------------------------------
#
# Although the ONNX exporter team makes its best effort to support all PyTorch operators, some of them
# might not be supported yet. In this section, we will demonstrate how you can add
# unsupported PyTorch operators to the ONNX Registry.
#
# .. note::
#    The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation of an existing
#    PyTorch operator with a custom one.
#    Because we don't actually have an unsupported PyTorch operator to use in this tutorial, we are going to leverage
#    this and replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom implementation the same way we would
#    if the operator was not implemented by the ONNX exporter.
#
# When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message
# similar to:
#
# .. code-block:: python
#
#    No decompositions registered for [...]
#
# The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``.
# The operator is of type ``<class 'torch._ops.OpOverload'>``, and this operator is what we will use as the
# target to register our custom implementation.

import torch
import onnxscript

# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op


# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add.Tensor(input_x, input_y)


# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
    if alpha != 1.0:
        alpha = op.CastLike(alpha, other)
        other = op.Mul(other, alpha)
    # To distinguish the custom implementation from the builtin one, we switch the order of the inputs
    return op.Add(other, self)


x = torch.tensor([1.0])
y = torch.tensor([2.0])

# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
    Model().eval(),
    (x, y),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add.Tensor: custom_aten_add,
    },
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()

######################################################################
# Now let's inspect the model and verify that it is using the custom implementation.

print(onnx_program.model)

######################################################################
# The translation is using our custom implementation: in node ``node_Add_0``, ``input_y`` now
# comes first, and ``input_x`` comes second.
#
# We can use ONNX Runtime to run the model and verify the results by calling
# the :class:`torch.onnx.ONNXProgram` directly on the input tensors.

result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))


######################################################################
# Using custom ONNX operators
# ---------------------------
#
# In this case, we create a model with standard PyTorch operators, but the runtime
# (such as Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
# existing implementation.
#
# In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime,
# which is not the same as the ``Gelu`` operator defined in the ONNX spec.


class GeluModel(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)


# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT


@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
    return microsoft_op.Gelu(self)


onnx_program = torch.onnx.export(
    GeluModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.gelu.default: custom_aten_gelu,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()


######################################################################
# Let's inspect the model and verify that it uses op_type ``Gelu``
# from namespace ``com.microsoft``.
#

print(onnx_program.model)

######################################################################
# Similar to the previous example, we can use ONNX Runtime to run the model and verify the results.

result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))


######################################################################
# Supporting a custom PyTorch operator
# ------------------------------------
#
# In this case, the operator is one that the user has implemented and registered with PyTorch.
#
# In the following example, we would like to use a custom operator
# that takes one tensor input and returns one output. The operator adds
# the input to itself and returns the rounded result.
#
# First, we assume the custom operator is implemented and registered with ``torch.library.custom_op()``.
# You can refer to `Creating new custom ops in Python <https://pytorch.org/docs/stable/library.html#torch.library.custom_op>`_
# for a detailed guide on how to create custom operators.


# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
    return torch.round(input + input)


@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
    return torch.empty_like(tensor_x)


class AddAndRoundModel(torch.nn.Module):
    def forward(self, input):
        return add_and_round_op(input)


# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
    return op.Round(op.Add(input, input))


onnx_program = torch.onnx.export(
    AddAndRoundModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)

######################################################################
# The translation is using our custom implementation to translate the ``torch.ops.mylibrary.add_and_round_op.default``
# operator in the :class:`torch.export.ExportedProgram` to the ONNX operators ``Add`` and ``Round``.
#

######################################################################
# Finally, we verify the results.

result = onnx_program(x)[0]
torch.testing.assert_close(result, add_and_round_op(x))

######################################################################
# Conclusion
# ----------
#
# Congratulations! In this tutorial, we explored the ``custom_translation_table`` option and
# discovered how to create custom implementations for unsupported or existing PyTorch operators
# using ONNX Script.
#
# Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch,
# providing us with a comprehensive understanding of handling unsupported
# operators in the ONNX ecosystem.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
#    :hidden:
#
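As a closing aside, the precedence behavior of ``custom_translation_table`` is simple: user-supplied translations are consulted before the exporter's builtin ones. The toy sketch below mimics that with a plain dictionary lookup. This is an illustration only; the names ``builtin_registry`` and ``resolve`` are invented for this sketch and are not part of the ``torch.onnx`` API.

```python
# Toy sketch of the lookup behind ``custom_translation_table``.
# NOTE: ``builtin_registry`` and ``resolve`` are illustrative names,
# not part of the torch.onnx API.

def builtin_add(a, b):
    # Stand-in for the exporter's builtin translation of aten::add.Tensor
    return a + b


def custom_add(a, b):
    # Stand-in for a user translation; swaps the operands, like the
    # ``custom_aten_add`` example in this tutorial
    return b + a


builtin_registry = {"aten::add.Tensor": builtin_add}
custom_translation_table = {"aten::add.Tensor": custom_add}


def resolve(op_name):
    # User-provided translations take precedence over builtin ones
    return custom_translation_table.get(op_name, builtin_registry.get(op_name))


print(resolve("aten::add.Tensor")("x", "y"))  # prints "yx": the custom translation won
```

With an empty ``custom_translation_table``, ``resolve`` would fall back to ``builtin_add`` and print ``"xy"`` instead, which mirrors how the exporter behaves when no override is registered for an operator.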