GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/python_custom_ops.py
# -*- coding: utf-8 -*-
"""
.. _python-custom-ops-tutorial:

Custom Python Operators
=======================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in Python with PyTorch
       * How to test custom operators using ``torch.library.opcheck``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (for example,
``torch.add`` and ``torch.sum``). However, you may want to use a new custom
operator with PyTorch, perhaps one written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons you may wish to create a custom operator in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, preventing ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function.

Use :func:`torch.library.custom_op` to create Python custom operators.
Use the C++ ``TORCH_LIBRARY`` APIs to create C++ custom operators (these
work in Python-less environments).
See the `Custom Operators Landing Page <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_
for more details.

Note that if your operation can be expressed as a composition of existing
PyTorch operators, then there is usually no need to use the custom operator
API -- everything (for example, ``torch.compile`` and training support) should
just work.
"""
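######################################################################
# For instance (a minimal sketch, not part of the original tutorial; the
# function name is made up), a function built only from existing PyTorch
# operators compiles with ``fullgraph=True`` without any custom operator
# machinery:

import torch

@torch.compile(fullgraph=True)
def composite_scale(x: torch.Tensor) -> torch.Tensor:
    # Built-in operators only, so torch.compile can trace straight through.
    return torch.sum(torch.add(x, 1.0)) / x.numel()

print(composite_scale(torch.ones(4)))  # tensor(2.)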
######################################################################
# Example: Wrapping PIL's crop into a custom operator
# ----------------------------------------------------
# Let's say that we are using PIL's ``crop`` operation.

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

######################################################################

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)
######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle, and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
# graph break occurs).

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)
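######################################################################
# If you want to observe the failure without stopping the tutorial, you can
# wrap the call in ``try``/``except`` (a sketch, not part of the original
# tutorial; the exact exception type and message depend on your PyTorch
# version):

try:
    f(img)
except Exception as e:
    print(f"torch.compile(fullgraph=True) raised: {e}")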
######################################################################
# In order to black-box ``crop`` for use with ``torch.compile``, we need to
# do two things:
#
# 1. Wrap the function into a PyTorch custom operator.
# 2. Add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
#    Given some ``FakeTensor`` inputs (dummy Tensors that don't have storage),
#    this function should return dummy Tensors of your choice with the correct
#    Tensor metadata (shape/strides/``dtype``/device).

from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator.
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)
    # The result should have the same metadata (shape/strides/``dtype``/device)
    # as running the ``crop`` function above.
    return result
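######################################################################
# A quick way to confirm the registration (a sketch, not part of the original
# tutorial): the new operator is also reachable under ``torch.ops``, and calling
# it eagerly returns a Tensor with the expected shape.

print(torch.ops.mylib.crop)
print(crop(img, (10, 10, 50, 50)).shape)  # torch.Size([3, 40, 40])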
######################################################################
# After this, ``crop`` now works without graph breaks:

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)

######################################################################

display(cropped_img)
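######################################################################
# As a quick sanity check (a sketch, not part of the original tutorial), the
# compiled call should produce the same result as calling the custom operator
# eagerly:

assert torch.allclose(cropped_img, crop(img, (10, 10, 50, 50)))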
######################################################################
# Adding training support for crop
# --------------------------------
# Use ``torch.library.register_autograd`` to add training support for an operator.
# Prefer this over directly using ``torch.autograd.Function``; some compositions of
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# have led to) silent incorrectness when composed with ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``.
# If you end up training with a ``custom_op`` that doesn't have an autograd
# registration, PyTorch raises an error, as the short demonstration below shows.
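######################################################################
# A short demonstration (a sketch, not part of the original tutorial): at this
# point ``crop`` does not have an autograd registration yet, so trying to
# backpropagate through it fails. The exact error message depends on your
# PyTorch version.

probe = torch.randn(3, 64, 64, requires_grad=True)
try:
    crop(probe, (10, 10, 50, 50)).sum().backward()
except Exception as e:
    print(f"Training without an autograd registration raised: {e}")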
######################################################################
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:
@torch.library.custom_op("mylib::paste", mutates_args=())
152
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
153
assert im1.device == im2.device
154
assert im1.dtype == im2.dtype
155
im1_pil = to_pil_image(im1.cpu())
156
im2_pil = to_pil_image(im2.cpu())
157
PIL.Image.Image.paste(im1_pil, im2_pil, coord)
158
return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)
159
160
@paste.register_fake
161
def _(im1, im2, coord):
162
assert im1.device == im2.device
163
assert im1.dtype == im2.dtype
164
return torch.empty_like(im1)
165
166
######################################################################
# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)
######################################################################
# Note that the backward must be a composition of PyTorch-understood operators,
# which is why we wrapped ``paste`` into a custom operator instead of directly
# using PIL's paste.

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)
######################################################################
# This is the correct gradient, with 1s (white) in the cropped region and 0s
# (black) in the unused region.
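######################################################################
# As a quick manual check (a sketch, not part of the original tutorial), we can
# assert that the gradient equals this mask exactly: ones inside the cropped
# region (rows and columns 10 through 49 for the box ``(10, 10, 50, 50)``) and
# zeros everywhere else.

expected_grad = torch.zeros_like(img)
expected_grad[:, 10:50, 10:50] = 1.0
assert torch.allclose(img.grad, expected_grad)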
######################################################################
# Testing Python Custom operators
# -------------------------------
# Use ``torch.library.opcheck`` to test that the custom operator was registered
# correctly. This does not test that the gradients are mathematically correct;
# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).
#
# To use ``opcheck``, pass it a set of example inputs to test against. If your
# operator supports training, then the examples should include Tensors that
# require grad. If your operator supports multiple devices, then the examples
# should include Tensors from each device.

examples = [
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)
######################################################################
# Mutable Python Custom operators
# -------------------------------
# You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
# the input and an output tensor and write ``input.sin()`` to the output tensor.
#
# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
# custom operator.

import numpy as np

@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)
######################################################################
# Because the operator doesn't return anything, there is no need to register
# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())
######################################################################
# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.
# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.

example_inputs = [
    [torch.randn(3), torch.empty(3)],
    [torch.randn(0, 3), torch.empty(0, 3)],
    [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)
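######################################################################
# To see what that failure mode looks like (a sketch, not part of the original
# tutorial; ``numpy_sin_broken`` is a deliberately incorrect registration), we
# can register a variant that mutates ``output`` without declaring it in
# ``mutates_args``. Either ``custom_op`` or ``opcheck`` rejects the
# inconsistency; the exact error depends on your PyTorch version.

try:
    @torch.library.custom_op("mylib::numpy_sin_broken", mutates_args=(), device_types="cpu")
    def numpy_sin_broken(input: torch.Tensor, output: torch.Tensor) -> None:
        np.sin(input.numpy(), out=output.numpy())

    torch.library.opcheck(numpy_sin_broken, [torch.randn(3), torch.empty(3)])
except Exception as e:
    print(f"The inconsistent registration was rejected: {e}")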
######################################################################
# Conclusion
# ----------
# In this tutorial, we learned how to use ``torch.library.custom_op`` to
# create a custom operator in Python that works with PyTorch subsystems
# such as ``torch.compile`` and autograd.
#
# This tutorial provides a basic introduction to custom operators.
# For more detailed information, see:
#
# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_
# - `the Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_
#