"""
Performance Tuning Guide
*************************
**Author**: `Szymon Migacz <https://github.com/szmigacz>`_

Performance Tuning Guide is a set of optimizations and best practices which can
accelerate training and inference of deep learning models in PyTorch. Presented
techniques often can be implemented by changing only a few lines of code and can
be applied to a wide range of deep learning models across all domains.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * General optimization techniques for PyTorch models
       * CPU-specific performance optimizations
       * GPU acceleration strategies
       * Distributed training optimizations

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.0 or later
       * Python 3.8 or later
       * CUDA-capable GPU (recommended for GPU optimizations)
       * Linux, macOS, or Windows operating system

Overview
--------

Performance optimization is crucial for efficient deep learning model training and inference.
This tutorial covers a comprehensive set of techniques to accelerate PyTorch workloads across
different hardware configurations and use cases.

General optimizations
---------------------
"""

import torch
import torchvision

###############################################################################
# Enable asynchronous data loading and augmentation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_
# supports asynchronous data loading and data augmentation in separate worker
# subprocesses.
# The default setting for ``DataLoader`` is ``num_workers=0``,
# which means that the data loading is synchronous and done in the main process.
# As a result the main training process has to wait for the data to be available
# to continue the execution.
#
# Setting ``num_workers > 0`` enables asynchronous data loading and overlap
# between the training and data loading. ``num_workers`` should be tuned
# depending on the workload, CPU, GPU, and location of training data.
#
# ``DataLoader`` accepts a ``pin_memory`` argument, which defaults to ``False``.
# When using a GPU it's better to set ``pin_memory=True``. This instructs
# ``DataLoader`` to use pinned memory and enables faster and asynchronous memory
# copy from the host to the GPU.

###############################################################################
# Disable gradient calculation for validation or inference
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch saves intermediate buffers from all operations which involve tensors
# that require gradients.
# Typically gradients aren't needed for validation or
# inference.
# The `torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#torch.no_grad>`_
# context manager can be applied to disable gradient calculation within a
# specified block of code. This accelerates execution and reduces the amount of
# required memory.
# `torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#torch.no_grad>`_
# can also be used as a function decorator.

###############################################################################
# Disable bias for convolutions directly followed by a batch norm
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.nn.Conv2d() <https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d>`_
# has a ``bias`` parameter which defaults to ``True`` (the same is true for
# `Conv1d <https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d>`_
# and
# `Conv3d <https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d>`_
# ).
#
# If a ``nn.Conv2d`` layer is directly followed by a ``nn.BatchNorm2d`` layer,
# then the bias in the convolution is not needed; instead use
# ``nn.Conv2d(..., bias=False, ....)``.
# Bias is not needed because in the first
# step ``BatchNorm`` subtracts the mean, which effectively cancels out the
# effect of bias.
#
# This is also applicable to 1d and 3d convolutions as long as ``BatchNorm`` (or
# other normalization layer) normalizes on the same dimension as convolution's
# bias.
#
# Models available from `torchvision <https://github.com/pytorch/vision>`_
# already implement this optimization.

###############################################################################
# Use parameter.grad = None instead of model.zero_grad() or optimizer.zero_grad()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling:
model.zero_grad()
# or
optimizer.zero_grad()

###############################################################################
# to zero out gradients, use the following method instead:

for param in model.parameters():
    param.grad = None

###############################################################################
# The second code snippet does not zero the memory of each individual parameter.
# In addition, the subsequent backward pass uses assignment instead of addition
# to store gradients, which reduces the number of memory operations.
#
# Setting gradient to ``None`` has a slightly different numerical behavior than
# setting it to zero, for more details refer to the
# `documentation <https://pytorch.org/docs/master/optim.html#torch.optim.Optimizer.zero_grad>`_.
#
# Alternatively, call ``model.zero_grad(set_to_none=True)`` or
# ``optimizer.zero_grad(set_to_none=True)``.

###############################################################################
# Fuse operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# Pointwise operations such as elementwise addition, multiplication, and math
# functions like ``sin()``, ``cos()``, ``sigmoid()``, etc., can be combined into a
# single kernel.
# This fusion helps reduce memory access and kernel launch times.
# Typically, pointwise operations are memory-bound; PyTorch eager-mode initiates
# a separate kernel for each operation, which involves loading data from memory,
# executing the operation (often not the most time-consuming step), and writing
# the results back to memory.
#
# By using a fused operator, only one kernel is launched for multiple pointwise
# operations, and data is loaded and stored just once. This efficiency is
# particularly beneficial for activation functions, optimizers, and custom RNN cells etc.
#
# PyTorch 2 introduces a compile-mode facilitated by TorchInductor, an underlying compiler
# that automatically fuses kernels. TorchInductor extends its capabilities beyond simple
# element-wise operations, enabling advanced fusion of eligible pointwise and reduction
# operations for optimized performance.
#
# In the simplest case fusion can be enabled by applying
# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
# decorator to the function definition, for example:

@torch.compile
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

###############################################################################
# Refer to
# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
# for more advanced use cases.

###############################################################################
# Enable channels_last memory format for computer vision models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch supports ``channels_last`` memory format for
# convolutional networks.
# This format is meant to be used in conjunction with
# `AMP <https://pytorch.org/docs/stable/amp.html>`_ to further accelerate
# convolutional neural networks with
# `Tensor Cores <https://www.nvidia.com/en-us/data-center/tensor-cores/>`_.
#
# Support for ``channels_last`` is experimental, but it's expected to work for
# standard computer vision models (e.g. ResNet-50, SSD). To convert models to
# ``channels_last`` format follow
# `Channels Last Memory Format Tutorial <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.
# The tutorial includes a section on
# `converting existing models <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models>`_.

###############################################################################
# Checkpoint intermediate buffers
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Buffer checkpointing is a technique to mitigate the memory capacity burden of
# model training. Instead of storing inputs of all layers to compute upstream
# gradients in backward propagation, it stores the inputs of a few layers and
# the others are recomputed during the backward pass. The reduced memory
# requirements enable increasing the batch size, which can improve utilization.
#
# Checkpointing targets should be selected carefully. The best approach is not
# to store large layer outputs that have a small re-computation cost. Example
# target layers are activation functions (e.g.
# ``ReLU``, ``Sigmoid``, ``Tanh``),
# up/down sampling and matrix-vector operations with small accumulation depth.
#
# PyTorch supports a native
# `torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_
# API to automatically perform checkpointing and recomputation.

###############################################################################
# Disable debugging APIs
# ~~~~~~~~~~~~~~~~~~~~~~
# Many PyTorch APIs are intended for debugging and should be disabled for
# regular training runs:
#
# * anomaly detection:
#   `torch.autograd.detect_anomaly <https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly>`_
#   or
#   `torch.autograd.set_detect_anomaly(True) <https://pytorch.org/docs/stable/autograd.html#torch.autograd.set_detect_anomaly>`_
# * profiler related:
#   `torch.autograd.profiler.emit_nvtx <https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.emit_nvtx>`_,
#   `torch.autograd.profiler.profile <https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.profile>`_
# * autograd ``gradcheck``:
#   `torch.autograd.gradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradcheck>`_
#   or
#   `torch.autograd.gradgradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradgradcheck>`_
#

###############################################################################
# CPU specific optimizations
# --------------------------

###############################################################################
# Utilize Non-Uniform Memory Access (NUMA) Controls
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NUMA or non-uniform memory access is a memory layout design used in data center machines meant to take advantage of locality of memory in multi-socket machines with multiple memory controllers and blocks.
# Generally speaking, all deep learning workloads, training or inference, get better performance without accessing hardware resources across NUMA nodes. Thus, inference can be run with multiple instances, each running on one socket, to raise throughput. For training tasks on a single node, distributed training is recommended so that each training process runs on one socket.
#
# In general cases the following command executes a PyTorch script on cores on the Nth node only, and avoids cross-socket memory access to reduce memory access overhead.
#
# .. code-block:: sh
#
#    numactl --cpunodebind=N --membind=N python <pytorch_script>

###############################################################################
# More detailed descriptions can be found `here <https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html>`_.

###############################################################################
# Utilize OpenMP
# ~~~~~~~~~~~~~~
# OpenMP is utilized to bring better performance for parallel computation tasks.
# ``OMP_NUM_THREADS`` is the easiest switch that can be used to accelerate computations. It determines the number of threads used for OpenMP computations.
# The CPU affinity setting controls how workloads are distributed over multiple cores. It affects communication overhead, cache line invalidation overhead, and page thrashing, so a proper CPU affinity setting brings performance benefits. ``GOMP_CPU_AFFINITY`` or ``KMP_AFFINITY`` determines how to bind OpenMP threads to physical processing units. Detailed information can be found `here <https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html>`_.

###############################################################################
# With the following command, PyTorch runs the task on N OpenMP threads.
#
# .. code-block:: sh
#
#    export OMP_NUM_THREADS=N

###############################################################################
# Typically, the following environment variables are used to set CPU affinity with the GNU OpenMP implementation. ``OMP_PROC_BIND`` specifies whether threads may be moved between processors. Setting it to CLOSE keeps OpenMP threads close to the primary thread in contiguous place partitions. ``OMP_SCHEDULE`` determines how OpenMP threads are scheduled. ``GOMP_CPU_AFFINITY`` binds threads to specific CPUs.
# An important tuning parameter is core pinning, which prevents threads from migrating between multiple CPUs, enhancing data locality and minimizing inter-core communication.
#
# .. code-block:: sh
#
#    export OMP_SCHEDULE=STATIC
#    export OMP_PROC_BIND=CLOSE
#    export GOMP_CPU_AFFINITY="N-M"

###############################################################################
# Intel OpenMP Runtime Library (``libiomp``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default, PyTorch uses GNU OpenMP (GNU ``libgomp``) for parallel computation. On Intel platforms, Intel OpenMP Runtime Library (``libiomp``) provides OpenMP API specification support. It sometimes brings more performance benefits compared to ``libgomp``. The ``LD_PRELOAD`` environment variable can be used to switch the OpenMP library to ``libiomp``:
#
# .. code-block:: sh
#
#    export LD_PRELOAD=<path>/libiomp5.so:$LD_PRELOAD

###############################################################################
# Similar to CPU affinity settings in GNU OpenMP, environment variables are provided in ``libiomp`` to control CPU affinity settings.
# ``KMP_AFFINITY`` binds OpenMP threads to physical processing units. ``KMP_BLOCKTIME`` sets the time, in milliseconds, that a thread should wait, after completing the execution of a parallel region, before sleeping.
# In most cases, setting ``KMP_BLOCKTIME`` to 1 or 0 yields good performance.
# The following commands show common settings with Intel OpenMP Runtime Library.
#
# .. code-block:: sh
#
#    export KMP_AFFINITY=granularity=fine,compact,1,0
#    export KMP_BLOCKTIME=1

###############################################################################
# Switch Memory allocator
# ~~~~~~~~~~~~~~~~~~~~~~~
# For deep learning workloads, ``Jemalloc`` or ``TCMalloc`` can get better performance than the default ``malloc`` function by reusing memory as much as possible. `Jemalloc <https://github.com/jemalloc/jemalloc>`_ is a general purpose ``malloc`` implementation that emphasizes fragmentation avoidance and scalable concurrency support. `TCMalloc <https://google.github.io/tcmalloc/overview.html>`_ also features a couple of optimizations to speed up program executions. One of them is holding memory in caches to speed up access of commonly-used objects. Holding such caches even after deallocation also helps avoid costly system calls if such memory is later re-allocated.
# Use environment variable ``LD_PRELOAD`` to take advantage of one of them.
#
# .. code-block:: sh
#
#    export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD


###############################################################################
# Train a model on CPU with PyTorch ``DistributedDataParallel`` (DDP) functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For small scale models or memory-bound models, such as DLRM, training on CPU is also a good choice. On a machine with multiple sockets, distributed training uses hardware resources efficiently to accelerate the training process.
# `Torch-ccl <https://github.com/intel/torch-ccl>`_ is optimized with Intel(R) ``oneCCL`` (collective communications library), which implements collectives such as ``allreduce``, ``allgather``, and ``alltoall`` for efficient distributed deep learning training. It implements the PyTorch C10D ``ProcessGroup`` API and can be dynamically loaded as an external ``ProcessGroup``. Building on the optimizations implemented in the PyTorch DDP module, ``torch-ccl`` accelerates communication operations. Besides the optimizations made to communication kernels, ``torch-ccl`` also features simultaneous computation-communication functionality.

###############################################################################
# GPU specific optimizations
# --------------------------

###############################################################################
# Enable Tensor cores
# ~~~~~~~~~~~~~~~~~~~~~~~
# Tensor cores are specialized hardware designed to compute matrix-matrix multiplication
# operations, primarily utilized in deep learning and AI workloads. Tensor cores have
# specific precision requirements which can be adjusted manually or via the Automatic
# Mixed Precision API.
#
# In particular, tensor operations take advantage of lower precision workloads,
# which can be controlled via ``torch.set_float32_matmul_precision``.
# The default format is set to ``'highest'``, which utilizes the tensor data type.
# However, PyTorch offers alternative precision settings: ``'high'`` and ``'medium'``.
# These options prioritize computational speed over numerical precision.

###############################################################################
# Use CUDA Graphs
# ~~~~~~~~~~~~~~~~~~~~~~~
# When using a GPU, work must first be launched from the CPU, and
# in some cases the context switch between CPU and GPU can lead to poor resource
# utilization.
# CUDA graphs are a way to keep computation within the GPU without
# paying the extra cost of kernel launches and host synchronization.

# CUDA graphs can be enabled using (``m`` stands for an existing module or function;
# ``mode`` is a keyword-only argument of ``torch.compile``):
torch.compile(m, mode="reduce-overhead")
# or
torch.compile(m, mode="max-autotune")

###############################################################################
# Support for CUDA graphs is in development. Their usage can incur increased
# device memory consumption, and some models might not compile.

###############################################################################
# Enable cuDNN auto-tuner
# ~~~~~~~~~~~~~~~~~~~~~~~
# `NVIDIA cuDNN <https://developer.nvidia.com/cudnn>`_ supports many algorithms
# to compute a convolution. Autotuner runs a short benchmark and selects the
# kernel with the best performance on a given hardware for a given input size.
#
# For convolutional networks (other types currently not supported), enable cuDNN
# autotuner before launching the training loop by setting:

torch.backends.cudnn.benchmark = True
###############################################################################
#
# * the auto-tuner decisions may be non-deterministic; a different algorithm may
#   be selected for different runs.
#   For more details see
#   `PyTorch: Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html?highlight=determinism>`_
# * in some rare cases, such as with highly variable input sizes, it's better
#   to run convolutional networks with autotuner disabled to avoid the overhead
#   associated with algorithm selection for each input size.
#

###############################################################################
# Avoid unnecessary CPU-GPU synchronization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Avoid unnecessary synchronizations, to let the CPU run ahead of the
# accelerator as much as possible to make sure that the accelerator work queue
# contains many operations.
#
# When possible, avoid operations which require synchronizations, for example:
#
# * ``print(cuda_tensor)``
# * ``cuda_tensor.item()``
# * memory copies: ``tensor.cuda()``, ``cuda_tensor.cpu()`` and equivalent
#   ``tensor.to(device)`` calls
# * ``cuda_tensor.nonzero()``
# * python control flow which depends on results of operations performed on CUDA
#   tensors e.g.
#   ``if (cuda_tensor != 0).all()``
#

###############################################################################
# Create tensors directly on the target device
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling ``torch.rand(size).cuda()`` to generate a random tensor,
# produce the output directly on the target device:
# ``torch.rand(size, device='cuda')``.
#
# This is applicable to all functions which create new tensors and accept
# ``device`` argument:
# `torch.rand() <https://pytorch.org/docs/stable/generated/torch.rand.html#torch.rand>`_,
# `torch.zeros() <https://pytorch.org/docs/stable/generated/torch.zeros.html#torch.zeros>`_,
# `torch.full() <https://pytorch.org/docs/stable/generated/torch.full.html#torch.full>`_
# and similar.

###############################################################################
# Use mixed precision and AMP
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Mixed precision leverages
# `Tensor Cores <https://www.nvidia.com/en-us/data-center/tensor-cores/>`_
# and offers up to 3x overall speedup on Volta and newer GPU architectures. To
# use Tensor Cores AMP should be enabled and matrix/tensor dimensions should
# satisfy requirements for calling kernels that use Tensor Cores.
#
# To use Tensor Cores:
#
# * set sizes to multiples of 8 (to map onto dimensions of Tensor Cores)
#
#   * see
#     `Deep Learning Performance Documentation
#     <https://docs.nvidia.com/deeplearning/performance/index.html#optimizing-performance>`_
#     for more details and guidelines specific to layer type
#   * if layer size is derived from other parameters rather than fixed, it can
#     still be explicitly padded e.g.
#     vocabulary size in NLP models
#
# * enable AMP
#
#   * Introduction to Mixed Precision Training and AMP:
#     `slides <https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf>`_
#   * native PyTorch AMP is available:
#     `documentation <https://pytorch.org/docs/stable/amp.html>`_,
#     `examples <https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples>`_,
#     `tutorial <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
#
#

###############################################################################
# Preallocate memory in case of variable input length
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Models for speech recognition or for NLP are often trained on input tensors
# with variable sequence length. Variable length can be problematic for PyTorch
# caching allocator and can lead to reduced performance or to unexpected
# out-of-memory errors. If a batch with a short sequence length is followed by
# another batch with longer sequence length, then PyTorch is forced to
# release intermediate buffers from previous iteration and to re-allocate new
# buffers. This process is time consuming and causes fragmentation in the
# caching allocator which may result in out-of-memory errors.
#
# A typical solution is to implement preallocation. It consists of the
# following steps:
#
# #. generate a (usually random) batch of inputs with maximum sequence length
#    (either corresponding to max length in the training dataset or to some
#    predefined threshold)
# #. execute a forward and a backward pass with the generated batch, do not
#    execute an optimizer or a learning rate scheduler, this step preallocates
#    buffers of maximum size, which can be reused in subsequent
#    training iterations
# #. zero out gradients
# #. proceed to regular training
#

###############################################################################
# Distributed optimizations
# -------------------------

###############################################################################
# Use efficient data-parallel backend
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch has two ways to implement data-parallel training:
#
# * `torch.nn.DataParallel <https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel>`_
# * `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
#
# ``DistributedDataParallel`` offers much better performance and scaling to
# multiple GPUs. For more information refer to the
# `relevant section of CUDA Best Practices <https://pytorch.org/docs/stable/notes/cuda.html#use-nn-parallel-distributeddataparallel-instead-of-multiprocessing-or-nn-dataparallel>`_
# from PyTorch documentation.

###############################################################################
# Skip unnecessary all-reduce if training with ``DistributedDataParallel`` and gradient accumulation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default
# `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# executes gradient all-reduce after every backward pass to compute the average
# gradient over all workers participating in the training.
# If training uses
# gradient accumulation over N steps, then all-reduce is not necessary after
# every training step. It's only required to perform all-reduce after the last
# call to backward, just before the execution of the optimizer.
#
# ``DistributedDataParallel`` provides the
# `no_sync() <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync>`_
# context manager which disables gradient all-reduce for a particular iteration.
# ``no_sync()`` should be applied to the first ``N-1`` iterations of gradient
# accumulation. The last iteration should follow the default execution and
# perform the required gradient all-reduce.

###############################################################################
# Match the order of layers in constructors and during the execution if using ``DistributedDataParallel(find_unused_parameters=True)``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# with ``find_unused_parameters=True`` uses the order of layers and parameters
# from model constructors to build buckets for ``DistributedDataParallel``
# gradient all-reduce. ``DistributedDataParallel`` overlaps all-reduce with the
# backward pass. All-reduce for a particular bucket is asynchronously triggered
# only when all gradients for parameters in a given bucket are available.
#
# To maximize the amount of overlap, the order in model constructors should
# roughly match the order during the execution.
# If the order doesn't match, then
# all-reduce for the entire bucket waits for the gradient which is the last to
# arrive. This may reduce the overlap between backward pass and all-reduce, and
# all-reduce may end up being exposed, which slows down the training.
#
# ``DistributedDataParallel`` with ``find_unused_parameters=False`` (which is
# the default setting) relies on automatic bucket formation based on order of
# operations encountered during the backward pass. With
# ``find_unused_parameters=False`` it's not necessary to reorder layers or
# parameters to achieve optimal performance.

###############################################################################
# Load-balance workload in a distributed setting
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load imbalance typically may happen for models processing sequential data
# (speech recognition, translation, language models etc.). If one device
# receives a batch of data with sequence length longer than sequence lengths for
# the remaining devices, then all devices wait for the worker which finishes
# last. Backward pass functions as an implicit synchronization point in a
# distributed setting with
# `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# backend.
#
# There are multiple ways to solve the load balancing problem. The core idea is
# to distribute workload over all workers as uniformly as possible within each
# global batch.
# For example Transformer solves imbalance by forming batches with
# approximately constant number of tokens (and variable number of sequences in a
# batch), other models solve imbalance by bucketing samples with similar
# sequence length or even by sorting dataset by sequence length.

###############################################################################
# Conclusion
# ----------
#
# This tutorial covered a comprehensive set of performance optimization techniques
# for PyTorch models. The key takeaways include:
#
# * **General optimizations**: Enable async data loading, disable gradients for
#   inference, fuse operations with ``torch.compile``, and use efficient memory formats
# * **CPU optimizations**: Leverage NUMA controls, optimize OpenMP settings, and
#   use efficient memory allocators
# * **GPU optimizations**: Enable Tensor cores, use CUDA graphs, enable cuDNN
#   autotuner, and implement mixed precision training
# * **Distributed optimizations**: Use DistributedDataParallel, optimize gradient
#   synchronization, and balance workloads across devices
#
# Many of these optimizations can be applied with minimal code changes and provide
# significant performance improvements across a wide range of deep learning models.
#
# Further Reading
# ---------------
#
# * `PyTorch Performance Tuning Documentation <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_
# * `CUDA Best Practices <https://pytorch.org/docs/stable/notes/cuda.html>`_
# * `Distributed Training Documentation <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
# * `Mixed Precision Training <https://pytorch.org/docs/stable/amp.html>`_
# * `torch.compile Tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
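
###############################################################################
# Example: combining the general optimizations
# --------------------------------------------
#
# The general optimizations described in this guide can be combined in a single
# training and validation loop. The following is a minimal sketch, not a
# reference implementation. The model, the synthetic dataset, the helper names
# ``build_model`` and ``run_example``, and all sizes are hypothetical and chosen
# only for illustration. It shows a ``DataLoader`` configured with
# ``num_workers`` and ``pin_memory``, a convolution without bias directly
# followed by ``BatchNorm2d``, ``optimizer.zero_grad(set_to_none=True)``, and
# ``torch.no_grad()`` for validation. ``num_workers`` defaults to ``0`` here so
# that the sketch runs anywhere; real workloads should tune it to a value
# greater than zero.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def build_model():
    # Hypothetical toy model. The convolution is directly followed by
    # BatchNorm2d, so its bias is redundant and is disabled.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 2),
    )


def run_example(num_workers=0):
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Synthetic stand-in data (assumed shapes, for illustration only).
    dataset = TensorDataset(
        torch.randn(32, 3, 16, 16), torch.randint(0, 2, (32,))
    )
    loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=num_workers,           # > 0 enables asynchronous loading
        pin_memory=device.type == "cuda",  # pinned memory only helps with a GPU
    )

    model = build_model().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for inputs, targets in loader:
        # non_blocking=True allows an asynchronous copy from pinned memory.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        # Drop gradients instead of filling them with zeros.
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

    optimizer.zero_grad(set_to_none=True)
    grads_are_none = all(p.grad is None for p in model.parameters())

    # Validation: no autograd bookkeeping is needed.
    model.eval()
    with torch.no_grad():
        outputs = model(torch.randn(4, 3, 16, 16, device=device))

    return grads_are_none, outputs.requires_grad, tuple(outputs.shape)
```

###############################################################################
# Calling ``run_example()`` returns whether all gradients are ``None`` after
# ``zero_grad(set_to_none=True)``, whether the validation output tracks
# gradients, and the output shape. On a machine with spare CPU cores, calling
# ``run_example(num_workers=2)`` exercises the asynchronous loading path.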