# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-tutorial:

Custom Python Operators
=======================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in Python with PyTorch
       * How to test custom operators using ``torch.library.opcheck``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (e.g.,
``torch.add``, ``torch.sum``, etc.). However, you might wish to use a new customized
operator with PyTorch, perhaps written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons why you may wish to create a custom operator in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, preventing ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function.

Use :func:`torch.library.custom_op` to create Python custom operators.
Use the C++ ``TORCH_LIBRARY`` APIs to create C++ custom operators (these
work in Python-less environments).
See the `Custom Operators Landing Page <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_
for more details.

Please note that if your operation can be expressed as a composition of
existing PyTorch operators, then there is usually no need to use the custom operator
API -- everything (for example ``torch.compile`` and training support) should
just work.
"""
######################################################################
# Example: Wrapping PIL's crop into a custom operator
# ---------------------------------------------------
# Let's say that we are using PIL's ``crop`` operation.

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

######################################################################

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)

######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle, and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
# graph break occurs).

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)

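######################################################################
# Without ``fullgraph=True``, ``torch.compile`` recovers from the graph break
# by running the unsupported code eagerly, so the function still works -- it is
# just no longer captured as a single graph. The snippet below is a small
# illustrative sketch of that fallback (``f_with_graph_breaks`` is just an
# example name); to see where the breaks happen, you can run your script with
# the ``TORCH_LOGS="graph_breaks"`` environment variable set.

@torch.compile  # no ``fullgraph=True``: graph breaks are allowed
def f_with_graph_breaks(img):
    return crop(img, (10, 10, 50, 50))

# Runs fine, but ``torch.compile`` falls back to eager execution around ``crop``.
cropped_img = f_with_graph_breaks(img)
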
######################################################################
# In order to black-box ``crop`` for use with ``torch.compile``, we need to
# do two things:
#
# 1. Wrap the function into a PyTorch custom operator.
# 2. Add a "``FakeTensor`` kernel" (also known as a "meta kernel") to the operator.
#    Given some ``FakeTensor`` inputs (dummy Tensors that don't have storage),
#    this function should return dummy Tensors of your choice with the correct
#    Tensor metadata (shape/strides/``dtype``/device).


from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator.
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)
    # The result should have the same metadata (shape/strides/``dtype``/device)
    # as running the ``crop`` function above.
    return result

######################################################################
# After this, ``crop`` now works without graph breaks:

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)

######################################################################

display(cropped_img)

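######################################################################
# As a quick sanity check (a small sketch; ``torch.library.opcheck``, covered
# later in this tutorial, automates this kind of testing), the compiled
# function should produce the same result as calling the custom operator
# eagerly:

assert torch.allclose(f(img), crop(img, (10, 10, 50, 50)))
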
######################################################################
# Adding training support for crop
# --------------------------------
# Use ``torch.library.register_autograd`` to add training support for an operator.
# Prefer this over directly using ``torch.autograd.Function``; some compositions of
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# have led to) silent incorrectness when composed with ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``.
# If you end up training with a ``custom_op`` that doesn't have an autograd
# registration, PyTorch will raise an error.
#
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:

@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)

######################################################################
# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators,
# which is why we wrapped ``paste`` into a custom operator instead of directly
# using PIL's paste.

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)

######################################################################
# This is the correct gradient, with 1s (white) in the cropped region and 0s
# (black) in the unused region.

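######################################################################
# We can also confirm this programmatically (a small sanity-check sketch):
# with the box ``(10, 10, 50, 50)`` used above (PIL boxes are
# ``(left, upper, right, lower)``), the gradient should be exactly 1 inside
# that region and 0 everywhere else.

assert (img.grad[:, 10:50, 10:50] == 1).all()
assert img.grad.sum() == 3 * 40 * 40
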
If your operator supports multiple devices, then the examples204# should include Tensors from each device.205206examples = [207[torch.randn(3, 64, 64), [0, 0, 10, 10]],208[torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],209[torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],210[torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],211]212213for example in examples:214torch.library.opcheck(crop, example)215216######################################################################217# Mutable Python Custom operators218# -------------------------------219# You can also wrap a Python function that mutates its inputs into a custom220# operator.221# Functions that mutate inputs are common because that is how many low-level222# kernels are written; for example, a kernel that computes ``sin`` may take in223# the input and an output tensor and write ``input.sin()`` to the output tensor.224#225# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python226# custom operator.227228import numpy as np229230@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")231def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:232assert input.device == output.device233assert input.device.type == "cpu"234input_np = input.numpy()235output_np = output.numpy()236np.sin(input_np, out=output_np)237238######################################################################239# Because the operator doesn't return anything, there is no need to register240# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.241242@torch.compile(fullgraph=True)243def f(x):244out = torch.empty(3)245numpy_sin(x, out)246return out247248x = torch.randn(3)249y = f(x)250assert torch.allclose(y, x.sin())251252######################################################################253# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.254# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.255256example_inputs = [257[torch.randn(3), torch.empty(3)],258[torch.randn(0, 3), torch.empty(0, 3)],259[torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],260]261262for example in example_inputs:263torch.library.opcheck(numpy_sin, example)264265######################################################################266# Conclusion267# ----------268# In this tutorial, we learned how to use ``torch.library.custom_op`` to269# create a custom operator in Python that works with PyTorch subsystems270# such as ``torch.compile`` and autograd.271#272# This tutorial provides a basic introduction to custom operators.273# For more detailed information, see:274#275# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_276# - `the Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_277#278279280