GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/onnx/onnx_registry_tutorial.py
# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
**Extending the ONNX exporter operator support** ||
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_

Extending the ONNX Exporter Operator Support
============================================

**Authors:** `Ti-Tai Wang <[email protected]>`_, `Justin Chu <[email protected]>`_
"""


###############################################################################
# Overview
# --------
#
# This tutorial describes how you can create an ONNX implementation for an unsupported PyTorch operator
# or replace an existing implementation with your own.
#
# We will cover three scenarios that require extending the ONNX exporter's operator support:
#
# * Overriding the implementation of an existing PyTorch operator
# * Using custom ONNX operators
# * Supporting a custom PyTorch operator
#
# What you will learn:
#
# - How to override or add support for PyTorch operators in ONNX.
# - How to integrate custom ONNX operators for specialized runtimes.
# - How to implement and translate custom PyTorch operators to ONNX.
#
# Prerequisites
# ~~~~~~~~~~~~~
#
# Before starting this tutorial, make sure you have completed the following prerequisites:
#
# * ``torch >= 2.6``
# * The target PyTorch operator
# * Completed the
#   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
#   before proceeding
# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__
#
# Overriding the implementation of an existing PyTorch operator
# -------------------------------------------------------------
#
# Although the ONNX exporter team makes its best effort to support all PyTorch operators, some of them
# might not be supported yet. In this section, we will demonstrate how you can add
# unsupported PyTorch operators to the ONNX Registry.
#
# .. note::
#   The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation
#   of an existing PyTorch operator with a custom one.
#   Because we don't have a genuinely unsupported PyTorch operator to use in this tutorial, we take advantage of
#   this fact and replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom one, just as we would
#   if the operator were not implemented by the ONNX exporter.
#
# When a model cannot be exported to ONNX because of an unsupported operator, the ONNX exporter shows an error message
# similar to:
#
# .. code-block:: python
#
#   No decompositions registered for [...]
#
# The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``.
# The operator is of type ``<class 'torch._ops.OpOverload'>``, and this operator is what we will use as the
# target to register our custom implementation.
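To see the exact signature a custom implementation must match, you can inspect the operator's schema. A quick sketch (note that ``_schema`` is an internal PyTorch attribute and may change between versions):

```python
import torch

# Print the native schema of the target operator; the parameter names
# ("self", "other", "alpha") are what our ONNX implementation must mirror.
print(torch.ops.aten.add.Tensor._schema)
```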

import torch
import onnxscript

# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op


# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add.Tensor(input_x, input_y)


# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
    if alpha != 1.0:
        alpha = op.CastLike(alpha, other)
        other = op.Mul(other, alpha)
    # To distinguish the custom implementation from the builtin one, we switch the order of the inputs
    return op.Add(other, self)


x = torch.tensor([1.0])
y = torch.tensor([2.0])

# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
    Model().eval(),
    (x, y),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add.Tensor: custom_aten_add,
    },
)
# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()

######################################################################
# Now let's inspect the model and verify that it is using the custom implementation.

print(onnx_program.model)

######################################################################
# The translation uses our custom implementation: in node ``node_Add_0``, ``input_y`` now
# comes first, and ``input_x`` comes second.
#
# We can use ONNX Runtime to run the model and verify the results by calling
# the :class:`torch.onnx.ONNXProgram` directly on the input tensors.

result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))


######################################################################
# Using custom ONNX operators
# ---------------------------
#
# In this case, we create a model with standard PyTorch operators, but the runtime
# (such as Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
# existing implementation.
#
# In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime,
# which is not the same as the ``Gelu`` operator in the ONNX spec.


class GeluModel(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)


# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT


@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
    return microsoft_op.Gelu(self)


onnx_program = torch.onnx.export(
    GeluModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.gelu.default: custom_aten_gelu,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()


######################################################################
# Let's inspect the model and verify that it uses the op_type ``Gelu``
# from the ``com.microsoft`` namespace.
#

print(onnx_program.model)

######################################################################
# Similar to the previous example, we can use ONNX Runtime to run the model and verify the results.

result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))


######################################################################
# Supporting a custom PyTorch operator
# ------------------------------------
#
# In this case, the operator is one that the user has implemented and registered with PyTorch.
#
# In the following example, we would like to use a custom operator
# that takes one tensor input and returns one output. The operator adds
# the input to itself and returns the rounded result.
#
# First, we assume the custom operator is implemented and registered with ``torch.library.custom_op()``.
# You can refer to `Creating new custom ops in Python <https://pytorch.org/docs/stable/library.html#torch.library.custom_op>`_
# for a detailed guide on how to create custom operators.

# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
    return torch.round(input + input)


@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
    return torch.empty_like(tensor_x)


class AddAndRoundModel(torch.nn.Module):
    def forward(self, input):
        return add_and_round_op(input)


# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
    return op.Round(op.Add(input, input))


onnx_program = torch.onnx.export(
    AddAndRoundModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)

######################################################################
# The translation uses our custom implementation to translate the ``torch.ops.mylibrary.add_and_round_op.default``
# operator in the :class:`torch.export.ExportedProgram` to the ONNX operators ``Add`` and ``Round``.
#

######################################################################
# Finally, we verify the results.

result = onnx_program(x)[0]
torch.testing.assert_close(result, add_and_round_op(x))

######################################################################
# Conclusion
# ----------
#
# Congratulations! In this tutorial, we explored the ``custom_translation_table`` option and
# discovered how to create custom implementations for unsupported or existing PyTorch operators
# using ONNX Script.
#
# Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch,
# giving us a comprehensive understanding of how to handle unsupported
# operators in the ONNX ecosystem.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of interest, or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
#    :hidden:
#