GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/recipes/tuning_guide.py
"""
Performance Tuning Guide
*************************
**Author**: `Szymon Migacz <https://github.com/szmigacz>`_

Performance Tuning Guide is a set of optimizations and best practices which can
accelerate training and inference of deep learning models in PyTorch. The
presented techniques can often be implemented by changing only a few lines of
code and can be applied to a wide range of deep learning models across all
domains.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * General optimization techniques for PyTorch models
       * CPU-specific performance optimizations
       * GPU acceleration strategies
       * Distributed training optimizations

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.0 or later
       * Python 3.8 or later
       * CUDA-capable GPU (recommended for GPU optimizations)
       * Linux, macOS, or Windows operating system

Overview
--------

Performance optimization is crucial for efficient deep learning model training
and inference. This tutorial covers a comprehensive set of techniques to
accelerate PyTorch workloads across different hardware configurations and use
cases.

General optimizations
---------------------
"""

import torch
import torchvision

###############################################################################
# Enable asynchronous data loading and augmentation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_
# supports asynchronous data loading and data augmentation in separate worker
# subprocesses. The default setting for ``DataLoader`` is ``num_workers=0``,
# which means that data loading is synchronous and done in the main process.
# As a result, the main training process has to wait for the data to be
# available before it can continue execution.
#
# Setting ``num_workers > 0`` enables asynchronous data loading and overlaps
# data loading with training. ``num_workers`` should be tuned depending on the
# workload, CPU, GPU, and location of the training data.
#
# ``DataLoader`` accepts a ``pin_memory`` argument, which defaults to ``False``.
# When using a GPU, it's better to set ``pin_memory=True``. This instructs
# ``DataLoader`` to use pinned memory, which enables faster and asynchronous
# memory copies from the host to the GPU.
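
###############################################################################
# The following is a minimal sketch of these settings, not an official recipe.
# The ``make_loader`` helper, the toy ``TensorDataset``, and the default of at
# most 4 workers are illustrative assumptions, so tune ``num_workers`` for your
# own workload. ``pin_memory`` is enabled only when CUDA is available, because
# pinned host memory only helps host-to-GPU copies.

```python
import os

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=32, num_workers=None):
    # Hypothetical starting point: up to 4 worker processes, bounded by CPU count.
    if num_workers is None:
        num_workers = min(4, os.cpu_count() or 1)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,               # > 0: batches are prepared in worker subprocesses
        pin_memory=torch.cuda.is_available(),  # page-locked host memory for faster H2D copies
        persistent_workers=num_workers > 0,    # keep workers alive between epochs
    )


# Toy dataset: 100 samples with 3 features and a binary label.
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
loader = make_loader(dataset, batch_size=10)
```

# With pinned memory, copies such as ``inputs.to("cuda", non_blocking=True)``
# can overlap with computation.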

###############################################################################
# Disable gradient calculation for validation or inference
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch saves intermediate buffers from all operations which involve tensors
# that require gradients. Typically, gradients aren't needed for validation or
# inference. The
# `torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#torch.no_grad>`_
# context manager can be applied to disable gradient calculation within a
# specified block of code. This accelerates execution and reduces the amount
# of required memory.
# `torch.no_grad() <https://pytorch.org/docs/stable/generated/torch.no_grad.html#torch.no_grad>`_
# can also be used as a function decorator.
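
###############################################################################
# Here is a short sketch of both forms. It uses a hypothetical
# ``torch.nn.Linear`` as a stand-in for a real network. The context manager
# wraps a block of code, and the decorator wraps a whole evaluation function.
# Outputs produced under ``torch.no_grad()`` do not record an autograd graph.

```python
import torch

model = torch.nn.Linear(8, 2)  # hypothetical stand-in for a trained model
inputs = torch.randn(4, 8)

# Form 1: context manager around a block of code.
with torch.no_grad():
    out_ctx = model(inputs)


# Form 2: decorator on an evaluation function.
@torch.no_grad()
def evaluate(model, inputs):
    model.eval()  # also switches layers such as dropout/batch norm to eval behavior
    return model(inputs)


out_dec = evaluate(model, inputs)
out_train = model(inputs)  # outside no_grad: the graph is recorded again
```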

###############################################################################
# Disable bias for convolutions directly followed by a batch norm
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.nn.Conv2d() <https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d>`_
# has a ``bias`` parameter which defaults to ``True`` (the same is true for
# `Conv1d <https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d>`_
# and
# `Conv3d <https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d>`_).
#
# If an ``nn.Conv2d`` layer is directly followed by an ``nn.BatchNorm2d``
# layer, then the bias in the convolution is not needed. Use
# ``nn.Conv2d(..., bias=False, ...)`` instead. The bias is not needed because
# in its first step ``BatchNorm`` subtracts the per-channel mean, which cancels
# out the constant per-channel offset added by the bias.
#
# This also applies to 1d and 3d convolutions, as long as ``BatchNorm`` (or
# another normalization layer) normalizes over the same dimension as the
# convolution's bias.
#
# Models available from `torchvision <https://github.com/pytorch/vision>`_
# already implement this optimization.
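
###############################################################################
# A minimal sketch of a hypothetical ``conv_bn_relu`` building block, followed
# by a numerical check of the cancellation argument above. The check compares
# two convolutions with identical weights, one with a bias and one without. It
# runs batch norm in training mode, where the layer normalizes with the
# statistics of the current batch.

```python
import torch
import torch.nn as nn


def conv_bn_relu(in_channels, out_channels):
    # bias=False: the BatchNorm2d that follows subtracts the per-channel mean,
    # so a per-channel convolution bias would be cancelled anyway.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


# Numerical check: same weights, with and without the convolution bias.
torch.manual_seed(0)
conv_with_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)

bn = nn.BatchNorm2d(8)  # training mode: normalizes with current-batch statistics
x = torch.randn(4, 3, 16, 16)
out_with_bias = bn(conv_with_bias(x))
out_no_bias = bn(conv_no_bias(x))
```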

###############################################################################
# Use parameter.grad = None instead of model.zero_grad() or optimizer.zero_grad()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The snippets in this section use a small illustrative model and optimizer.

model = torch.nn.Linear(16, 4)  # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

###############################################################################
# Instead of calling:

model.zero_grad()
# or
optimizer.zero_grad()

###############################################################################
# to zero out gradients, use the following method instead:

for param in model.parameters():
    param.grad = None

###############################################################################
# The second code snippet does not zero the memory of each individual
# parameter. The subsequent backward pass also uses assignment instead of
# addition to store gradients, which reduces the number of memory operations.
#
# Setting the gradient to ``None`` has a slightly different numerical behavior
# than setting it to zero. For more details, refer to the
# `documentation <https://pytorch.org/docs/master/optim.html#torch.optim.Optimizer.zero_grad>`_.
#
# Alternatively, call ``model.zero_grad(set_to_none=True)`` or
# ``optimizer.zero_grad(set_to_none=True)``. Since PyTorch 2.0,
# ``set_to_none=True`` is the default for both methods.
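
###############################################################################
# A small sketch of the numerical difference mentioned above, assuming plain
# SGD with weight decay on a hypothetical ``torch.nn.Linear`` model. PyTorch
# optimizers skip parameters whose gradient is ``None``. When the gradient is
# a zero tensor, they still perform the update, so weight decay moves the
# weights even without a new backward pass.

```python
import torch


def step_after_zero_grad(set_to_none):
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)  # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.1)
    model(torch.randn(2, 4)).sum().backward()
    optimizer.zero_grad(set_to_none=set_to_none)
    before = [p.detach().clone() for p in model.parameters()]
    optimizer.step()  # no new backward pass: only weight decay could move params
    changed = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))
    grads_are_none = all(p.grad is None for p in model.parameters())
    return changed, grads_are_none


changed_zero, none_zero = step_after_zero_grad(set_to_none=False)  # zero-filled grads
changed_none, none_none = step_after_zero_grad(set_to_none=True)   # grads set to None
```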

###############################################################################
# Fuse operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# Pointwise operations such as elementwise addition, multiplication, and math
# functions like ``sin()``, ``cos()``, ``sigmoid()``, etc., can be combined into
# a single kernel. This fusion helps reduce memory access and kernel launch
# overhead. Pointwise operations are typically memory-bound. In eager mode,
# PyTorch launches a separate kernel for each operation. Each launch involves
# loading data from memory, executing the operation (often not the most
# time-consuming step), and writing the results back to memory.
#
# With a fused operator, only one kernel is launched for multiple pointwise
# operations, and data is loaded and stored just once. This efficiency is
# particularly beneficial for activation functions, optimizers, custom RNN
# cells, and similar code.
#
# PyTorch 2 introduces a compile mode backed by TorchInductor, an underlying
# compiler that automatically fuses kernels. TorchInductor goes beyond simple
# element-wise operations and enables advanced fusion of eligible pointwise and
# reduction operations for optimized performance.
#
# In the simplest case, fusion can be enabled by applying the
# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
# decorator to the function definition, for example:

@torch.compile
def gelu(x):
    # erf-based GELU; 1.41421 approximates sqrt(2).
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

###############################################################################
# Refer to
# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
# for more advanced use cases.

###############################################################################
# Enable channels_last memory format for computer vision models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch supports the ``channels_last`` memory format for convolutional
# networks. This format is meant to be used in conjunction with
# `AMP <https://pytorch.org/docs/stable/amp.html>`_ to further accelerate
# convolutional neural networks with
# `Tensor Cores <https://www.nvidia.com/en-us/data-center/tensor-cores/>`_.
#
# Support for ``channels_last`` is experimental, but it's expected to work for
# standard computer vision models (e.g. ResNet-50, SSD). To convert models to
# the ``channels_last`` format, follow the
# `Channels Last Memory Format Tutorial <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.
# The tutorial includes a section on
# `converting existing models <https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models>`_.
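
###############################################################################
# A minimal conversion sketch using a hypothetical three-layer CNN. Both the
# model (its 4D parameters) and the input tensors are converted. The logical
# tensor shape stays NCHW, and only the strides change. On a GPU you would
# combine this with AMP as described above. The snippet itself runs on CPU for
# illustration.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN standing in for a real vision model.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
# Convert the 4D parameters (the conv weight) to channels_last (NHWC) strides.
model = model.to(memory_format=torch.channels_last)

# Inputs must be converted as well; the logical shape stays (N, C, H, W).
x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)

model.eval()
with torch.no_grad():
    out = model(x)
```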

###############################################################################
# Checkpoint intermediate buffers
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Buffer checkpointing is a technique to mitigate the memory capacity burden of
# model training. Instead of storing the inputs of all layers to compute
# upstream gradients in backward propagation, it stores the inputs of a few
# layers and recomputes the others during the backward pass. The reduced memory
# requirements enable a larger batch size, which can improve utilization.
#
# Checkpointing targets should be selected carefully. It is best not to store
# large layer outputs that have a small re-computation cost. Example target
# layers are activation functions (e.g. ``ReLU``, ``Sigmoid``, ``Tanh``),
# up/down sampling, and matrix-vector operations with small accumulation depth.
#
# PyTorch supports a native
# `torch.utils.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_
# API to automatically perform checkpointing and recomputation.
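
###############################################################################
# A minimal sketch, assuming a hypothetical stack of ``Linear`` + ``ReLU``
# blocks in which each block is wrapped with
# ``torch.utils.checkpoint.checkpoint``. It passes ``use_reentrant=False``, the
# variant the current documentation recommends. Gradients are the same as
# without checkpointing; only the memory/compute trade-off changes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):
    """Hypothetical model: a stack of Linear + ReLU blocks, each checkpointed."""

    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(depth)
        )

    def forward(self, x):
        for block in self.blocks:
            # Only the block's input is kept; activations inside the block are
            # recomputed during the backward pass.
            x = checkpoint(block, x, use_reentrant=False)
        return x


torch.manual_seed(0)
model = CheckpointedMLP(dim=16, depth=3)
x = torch.randn(5, 16, requires_grad=True)
model(x).sum().backward()
grad_checkpointed = x.grad.clone()
```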

###############################################################################
# Disable debugging APIs
# ~~~~~~~~~~~~~~~~~~~~~~
# Many PyTorch APIs are intended for debugging and should be disabled for
# regular training runs:
#
# * anomaly detection:
#   `torch.autograd.detect_anomaly <https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly>`_
#   or
#   `torch.autograd.set_detect_anomaly(True) <https://pytorch.org/docs/stable/autograd.html#torch.autograd.set_detect_anomaly>`_
# * profiler related:
#   `torch.autograd.profiler.emit_nvtx <https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.emit_nvtx>`_,
#   `torch.autograd.profiler.profile <https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.profile>`_
# * autograd ``gradcheck``:
#   `torch.autograd.gradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradcheck>`_
#   or
#   `torch.autograd.gradgradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradgradcheck>`_
#
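
###############################################################################
# One way to keep these tools out of regular runs is to gate them behind a
# single flag. ``DEBUG`` and ``maybe_detect_anomaly`` are hypothetical names
# used for illustration. When the flag is off, anomaly detection is disabled
# globally and the context manager is a no-op.

```python
import contextlib

import torch

DEBUG = False  # hypothetical switch; set to True only while investigating a problem

# Global switch: anomaly detection stays off for regular training runs.
torch.autograd.set_detect_anomaly(DEBUG)


def maybe_detect_anomaly():
    # Enable anomaly detection only in debug runs; otherwise a no-op context.
    return torch.autograd.detect_anomaly() if DEBUG else contextlib.nullcontext()


x = torch.randn(3, requires_grad=True)
with maybe_detect_anomaly():
    (x * 2).sum().backward()
```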

###############################################################################
# CPU specific optimizations
# --------------------------

###############################################################################
# Utilize Non-Uniform Memory Access (NUMA) Controls
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NUMA, or non-uniform memory access, is a memory layout design used in data
# center machines. It takes advantage of memory locality in multi-socket
# machines with multiple memory controllers and blocks. Generally speaking,
# deep learning workloads, whether training or inference, get better
# performance when they avoid accessing hardware resources across NUMA nodes.
# Thus, inference can be run as multiple instances, each running on one socket,
# to raise throughput. For training tasks on a single node, distributed training
# is recommended so that each training process runs on one socket.
#
# In general, the following command executes a PyTorch script on the cores of
# the Nth node only. This avoids cross-socket memory access and reduces memory
# access overhead.
#
# .. code-block:: sh
#
#    numactl --cpunodebind=N --membind=N python <pytorch_script>

###############################################################################
# More detailed descriptions can be found `here <https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html>`_.
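
###############################################################################
# A sketch of the "one instance per socket" pattern as a small Python
# launcher. It assumes ``numactl`` is installed and that you know the number
# of NUMA nodes (for example from ``numactl --hardware``). The helper names and
# the ``infer.py`` script are hypothetical.

```python
import subprocess


def numactl_command(node, script, *script_args):
    # Bind both CPU execution and memory allocation to a single NUMA node.
    return ["numactl", f"--cpunodebind={node}", f"--membind={node}",
            "python", script, *script_args]


def launch_per_node(script, num_nodes):
    """Start one instance of `script` per NUMA node and wait for all of them."""
    procs = [subprocess.Popen(numactl_command(node, script)) for node in range(num_nodes)]
    return [p.wait() for p in procs]


# Example (not executed here): two inference instances on a two-socket machine.
# launch_per_node("infer.py", num_nodes=2)  # "infer.py" is a hypothetical script
```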

###############################################################################
# Utilize OpenMP
# ~~~~~~~~~~~~~~
# OpenMP is used to improve the performance of parallel computation tasks.
# ``OMP_NUM_THREADS`` is the easiest switch for accelerating computations. It
# determines the number of threads used for OpenMP computations.
#
# CPU affinity settings control how workloads are distributed over multiple
# cores. They affect communication overhead, cache line invalidation overhead,
# and page thrashing, so setting CPU affinity properly brings performance
# benefits. ``GOMP_CPU_AFFINITY`` or ``KMP_AFFINITY`` determines how to bind
# OpenMP threads to physical processing units. Detailed information can be
# found `here <https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html>`_.

###############################################################################
# With the following command, PyTorch runs the task on N OpenMP threads.
#
# .. code-block:: sh
#
#    export OMP_NUM_THREADS=N

###############################################################################
# Typically, the following environment variables are used to set CPU affinity
# with the GNU OpenMP implementation:
#
# * ``OMP_PROC_BIND`` specifies whether threads may be moved between
#   processors. Setting it to ``CLOSE`` keeps OpenMP threads close to the
#   primary thread in contiguous place partitions.
# * ``OMP_SCHEDULE`` determines how OpenMP threads are scheduled.
# * ``GOMP_CPU_AFFINITY`` binds threads to specific CPUs.
#
# An important tuning parameter is core pinning. It prevents threads from
# migrating between multiple CPUs, which enhances data locality and minimizes
# inter-core communication.
#
# .. code-block:: sh
#
#    export OMP_SCHEDULE=STATIC
#    export OMP_PROC_BIND=CLOSE
#    export GOMP_CPU_AFFINITY="N-M"
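
###############################################################################
# A sketch of applying these variables from a Python launcher. It builds the
# environment for a child process rather than modifying the current one. The
# helper name, the core range, and the ``train.py`` script are hypothetical,
# and the helper assumes the cores it names exist and are free.

```python
import os


def gnu_openmp_env(num_threads, first_core, base_env=None):
    """Build an environment for a child process with pinned GNU OpenMP threads.

    Assumes cores first_core .. first_core + num_threads - 1 exist and are free.
    """
    env = dict(os.environ if base_env is None else base_env)
    last_core = first_core + num_threads - 1
    env.update({
        "OMP_NUM_THREADS": str(num_threads),
        "OMP_SCHEDULE": "STATIC",
        "OMP_PROC_BIND": "CLOSE",
        "GOMP_CPU_AFFINITY": f"{first_core}-{last_core}",
    })
    return env


# Example (not executed here); "train.py" is a hypothetical script:
# import subprocess
# subprocess.run(["python", "train.py"], env=gnu_openmp_env(4, first_core=0), check=True)
```

# Passing the variables to a new process ensures the OpenMP runtime sees them
# when it initializes. Inside an already running process, ``torch.set_num_threads``
# adjusts the number of threads PyTorch uses for intra-op parallelism.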

###############################################################################
# Intel OpenMP Runtime Library (``libiomp``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default, PyTorch uses GNU OpenMP (GNU ``libgomp``) for parallel
# computation. On Intel platforms, the Intel OpenMP Runtime Library
# (``libiomp``) provides support for the OpenMP API specification. It sometimes
# brings more performance benefits than ``libgomp``. You can switch the
# OpenMP library to ``libiomp`` with the environment variable ``LD_PRELOAD``:
#
# .. code-block:: sh
#
#    export LD_PRELOAD=<path>/libiomp5.so:$LD_PRELOAD

###############################################################################
# Similar to the CPU affinity settings in GNU OpenMP, ``libiomp`` provides
# environment variables to control CPU affinity.
# ``KMP_AFFINITY`` binds OpenMP threads to physical processing units.
# ``KMP_BLOCKTIME`` sets the time, in milliseconds, that a thread should wait
# after completing the execution of a parallel region before sleeping. In most
# cases, setting ``KMP_BLOCKTIME`` to 1 or 0 yields good performance.
# The following commands show common settings for the Intel OpenMP Runtime
# Library.
#
# .. code-block:: sh
#
#    export KMP_AFFINITY=granularity=fine,compact,1,0
#    export KMP_BLOCKTIME=1
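
###############################################################################
# A sketch of the same settings built for a child process. The path to
# ``libiomp5.so`` depends on your installation and is passed in by the caller.
# The helper prepends to any existing ``LD_PRELOAD`` entries, and unlike the
# shell one-liner it leaves no empty entry when ``LD_PRELOAD`` was unset. The
# helper name is hypothetical.

```python
import os


def intel_openmp_env(libiomp_path, base_env=None):
    """Build a child-process environment that preloads libiomp with common KMP settings.

    libiomp_path is the full path to libiomp5.so on your system.
    """
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("LD_PRELOAD")
    # Prepend, keeping any libraries that were already preloaded.
    env["LD_PRELOAD"] = f"{libiomp_path}:{existing}" if existing else libiomp_path
    env["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
    env["KMP_BLOCKTIME"] = "1"
    return env
```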

###############################################################################
# Switch Memory allocator
# ~~~~~~~~~~~~~~~~~~~~~~~
# For deep learning workloads, ``Jemalloc`` or ``TCMalloc`` can achieve better
# performance than the default ``malloc`` function by reusing memory as much as
# possible. `Jemalloc <https://github.com/jemalloc/jemalloc>`_ is a general
# purpose ``malloc`` implementation that emphasizes fragmentation avoidance and
# scalable concurrency support.
# `TCMalloc <https://google.github.io/tcmalloc/overview.html>`_ also features a
# couple of optimizations to speed up program execution. One of them is holding
# memory in caches to speed up access to commonly used objects. Holding such
# caches even after deallocation also helps avoid costly system calls if such
# memory is later re-allocated.
# Use the environment variable ``LD_PRELOAD`` to take advantage of one of them.
#
# .. code-block:: sh
#
#    export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD
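
###############################################################################
# A preload path can be mistyped without any visible error. The following
# sketch is a quick Linux-only check of whether jemalloc or tcmalloc was
# actually loaded: it scans ``/proc/self/maps`` for the library names. The
# helper name is hypothetical. The check is a simple name match and returns an
# empty set on systems without ``/proc``.

```python
def loaded_allocators(maps_path="/proc/self/maps"):
    """Return which of jemalloc/tcmalloc are mapped into this process (Linux only)."""
    found = set()
    try:
        with open(maps_path) as maps:
            for line in maps:
                for name in ("jemalloc", "tcmalloc"):
                    if name in line:
                        found.add(name)
    except FileNotFoundError:
        pass  # e.g. macOS or Windows: no /proc filesystem
    return found


# Non-empty only if jemalloc or tcmalloc was preloaded (or linked) into this process.
allocators = loaded_allocators()
```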

###############################################################################
# Train a model on CPU with PyTorch ``DistributedDataParallel`` (DDP) functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For small-scale models or memory-bound models, such as DLRM, training on CPU
# is also a good choice. On a machine with multiple sockets, distributed
# training makes highly efficient use of hardware resources to accelerate the
# training process.
#
# `Torch-ccl <https://github.com/intel/torch-ccl>`_ is optimized with Intel(R)
# ``oneCCL`` (collective communications library) for efficient distributed deep
# learning training. It implements collectives such as ``allreduce``,
# ``allgather``, and ``alltoall``. It also implements the PyTorch C10D
# ``ProcessGroup`` API and can be dynamically loaded as an external
# ``ProcessGroup``. On top of the optimizations implemented in the PyTorch DDP
# module, ``torch-ccl`` accelerates communication operations. Besides the
# optimizations made to communication kernels, ``torch-ccl`` also features
# simultaneous computation-communication functionality.
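
###############################################################################
# A minimal sketch of one DDP training step on CPU. It uses PyTorch's built-in
# ``gloo`` backend and a file-based rendezvous. The ``train_one_step`` helper
# and the toy ``Linear`` model are hypothetical. With ``torch-ccl`` installed,
# its ``"ccl"`` backend name could be passed instead; check the torch-ccl README
# for the exact import it requires, since that is assumed here, not shown. The
# run below uses a single process (``world_size=1``) for illustration.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train_one_step(rank, world_size, init_file, backend="gloo"):
    # "gloo" is PyTorch's built-in CPU backend; "ccl" (torch-ccl) is an assumed alternative.
    dist.init_process_group(backend, init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    try:
        torch.manual_seed(0)
        model = DDP(nn.Linear(10, 1))  # CPU module: no device_ids
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss = model(torch.randn(8, 10)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()  # gradients are all-reduced across ranks here
        optimizer.step()
        return loss.item()
    finally:
        dist.destroy_process_group()


# Single-process run for illustration. A real job starts one process per
# socket (for example with torchrun) and binds each process with numactl.
init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
loss = train_one_step(rank=0, world_size=1, init_file=init_file)
```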

###############################################################################
# GPU specific optimizations
# --------------------------

###############################################################################
# Enable Tensor cores
# ~~~~~~~~~~~~~~~~~~~~~~~
# Tensor cores are specialized hardware designed to compute matrix-matrix
# multiplication operations, primarily used in deep learning and AI workloads.
# Tensor cores have specific precision requirements, which can be adjusted
# manually or via the Automatic Mixed Precision API.
#
# In particular, float32 matrix multiplications can take advantage of
# lower-precision Tensor core computation, which is controlled via
# ``torch.set_float32_matmul_precision``. The default setting is
# ``'highest'``, which computes float32 matrix multiplications with full
# float32 precision internally. PyTorch also offers the alternative settings
# ``'high'`` and ``'medium'``, which allow lower-precision internal computation
# (such as TF32 or bfloat16). These options prioritize computational speed over
# numerical precision.
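
###############################################################################
# A short sketch of switching the precision setting around a matrix
# multiplication and restoring it afterwards. Whether the ``'high'`` setting
# changes anything depends on the hardware: on GPUs with TF32 support it may
# route float32 matmuls to Tensor cores, and elsewhere it may have no effect.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

previous = torch.get_float32_matmul_precision()  # "highest" unless changed earlier
torch.set_float32_matmul_precision("high")  # allow TF32-like internal precision

a = torch.randn(256, 256, device=device)
b = torch.randn(256, 256, device=device)
# On GPUs with TF32 support (Ampere or newer), this float32 matmul may use
# Tensor cores; on other hardware the setting may have no effect.
c = a @ b

torch.set_float32_matmul_precision(previous)  # restore for precision-sensitive code
```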

###############################################################################
# Use CUDA Graphs
# ~~~~~~~~~~~~~~~~~~~~~~~
# When using a GPU, work must first be launched from the CPU. In some cases,
# the overhead of these launches and of synchronization between the CPU and
# the GPU can lead to poor resource utilization. CUDA graphs keep computation
# within the GPU without paying the extra cost of kernel launches and host
# synchronization.

# A small illustrative module; compilation is lazy and happens on the first call.
m = torch.nn.Linear(16, 16)

# CUDA graphs can be enabled using
reduce_overhead_m = torch.compile(m, mode="reduce-overhead")
# or
max_autotune_m = torch.compile(m, mode="max-autotune")

###############################################################################
# Support for CUDA graphs is in development. Using them can increase device
# memory consumption, and some models might not compile.

###############################################################################
# Enable cuDNN auto-tuner
# ~~~~~~~~~~~~~~~~~~~~~~~
# `NVIDIA cuDNN <https://developer.nvidia.com/cudnn>`_ supports many algorithms
# to compute a convolution. The auto-tuner runs a short benchmark and selects
# the kernel with the best performance on the given hardware for a given input
# size.
#
# For convolutional networks (other types are currently not supported), enable
# the cuDNN auto-tuner before launching the training loop by setting:

torch.backends.cudnn.benchmark = True

###############################################################################
#
# * The auto-tuner decisions may be non-deterministic; different algorithms may
#   be selected for different runs. For more details, see
#   `PyTorch: Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html?highlight=determinism>`_.
# * In some rare cases, such as with highly variable input sizes, it's better
#   to run convolutional networks with the auto-tuner disabled to avoid the
#   overhead associated with algorithm selection for each input size.
#

###############################################################################
# Avoid unnecessary CPU-GPU synchronization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Avoid unnecessary synchronizations so that the CPU can run ahead of the
# accelerator as much as possible. This keeps the accelerator work queue filled
# with many operations.
#
# When possible, avoid operations which require synchronization, for example:
#
# * ``print(cuda_tensor)``
# * ``cuda_tensor.item()``
# * memory copies: ``tensor.cuda()``, ``cuda_tensor.cpu()`` and equivalent
#   ``tensor.to(device)`` calls
# * ``cuda_tensor.nonzero()``
# * Python control flow which depends on results of operations performed on
#   CUDA tensors, e.g. ``if (cuda_tensor != 0).all()``
#
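
###############################################################################
# A common case is logging the loss. Calling ``loss.item()`` on every
# iteration forces a synchronization each time. The sketch below keeps a
# running total on the device and converts it to a Python number only once per
# epoch. The model, data, and helper name are hypothetical.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"


def epoch_mean_loss(model, batches, loss_fn):
    """Average loss over batches, synchronizing with the device only once."""
    total = torch.zeros((), device=device)
    for inputs, targets in batches:
        loss = loss_fn(model(inputs), targets)
        # Accumulate on the device; calling loss.item() here would force a
        # CPU-GPU synchronization on every iteration.
        total += loss.detach()
    return (total / len(batches)).item()  # single synchronization at the end


torch.manual_seed(0)
model = nn.Linear(4, 1).to(device)  # hypothetical model
batches = [(torch.randn(8, 4, device=device), torch.randn(8, 1, device=device))
           for _ in range(5)]
mean_loss = epoch_mean_loss(model, batches, nn.MSELoss())
```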

###############################################################################
# Create tensors directly on the target device
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling ``torch.rand(size).cuda()`` to generate a random tensor,
# produce the output directly on the target device:
# ``torch.rand(size, device='cuda')``.
#
# This applies to all functions which create new tensors and accept a
# ``device`` argument:
# `torch.rand() <https://pytorch.org/docs/stable/generated/torch.rand.html#torch.rand>`_,
# `torch.zeros() <https://pytorch.org/docs/stable/generated/torch.zeros.html#torch.zeros>`_,
# `torch.full() <https://pytorch.org/docs/stable/generated/torch.full.html#torch.full>`_
# and similar.
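
###############################################################################
# A short illustration with a device chosen at runtime, so it also runs
# without a GPU. The ``*_like`` factories are a convenient way to match the
# device and dtype of an existing tensor.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
size = (1024, 1024)

# Slower pattern: allocate on the CPU, then copy to the device.
# x = torch.rand(size).cuda()

# Preferred: allocate directly on the target device.
x = torch.rand(size, device=device)
mask = torch.zeros(size, dtype=torch.bool, device=device)
fill = torch.full(size, 0.5, device=device)

# *_like factories inherit device and dtype from an existing tensor.
y = torch.zeros_like(x)
```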

###############################################################################
# Use mixed precision and AMP
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Mixed precision leverages
# `Tensor Cores <https://www.nvidia.com/en-us/data-center/tensor-cores/>`_
# and offers up to 3x overall speedup on Volta and newer GPU architectures. To
# use Tensor Cores, AMP should be enabled and matrix/tensor dimensions should
# satisfy the requirements for calling kernels that use Tensor Cores.
#
# To use Tensor Cores:
#
# * set sizes to multiples of 8 (to map onto dimensions of Tensor Cores)
#
#   * see
#     `Deep Learning Performance Documentation
#     <https://docs.nvidia.com/deeplearning/performance/index.html#optimizing-performance>`_
#     for more details and guidelines specific to layer type
#   * if layer size is derived from other parameters rather than fixed, it can
#     still be explicitly padded, e.g. vocabulary size in NLP models
#
# * enable AMP
#
#   * Introduction to Mixed Precision Training and AMP:
#     `slides <https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf>`_
#   * native PyTorch AMP is available:
#     `documentation <https://pytorch.org/docs/stable/amp.html>`_,
#     `examples <https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples>`_,
#     `tutorial <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
#
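
###############################################################################
# A minimal sketch combining both recommendations. A derived layer size is
# padded to a multiple of 8, and one training step runs under
# ``torch.autocast`` with a ``GradScaler``. The sizes, model, and helper names
# are hypothetical. On CUDA it uses float16 with loss scaling. On CPU it falls
# back to bfloat16 autocast with the scaler disabled, so the sketch also runs
# without a GPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_to_multiple(n, multiple=8):
    """Round a layer dimension up to the next multiple (8 maps onto Tensor Cores)."""
    return ((n + multiple - 1) // multiple) * multiple


use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
# float16 on CUDA; bfloat16 is the autocast dtype supported on CPU.
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

num_classes = pad_to_multiple(1001)  # hypothetical derived size, padded to 1008
model = nn.Linear(64, num_classes).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Loss scaling is only needed for float16; a disabled GradScaler passes calls through.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)


def train_step(inputs, targets):
    optimizer.zero_grad(set_to_none=True)
    # Eligible ops (such as the matmul in nn.Linear) run in amp_dtype.
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = F.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```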

###############################################################################
# Preallocate memory in case of variable input length
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Models for speech recognition or for NLP are often trained on input tensors
# with variable sequence length. Variable length can be problematic for the
# PyTorch caching allocator and can lead to reduced performance or to
# unexpected out-of-memory errors. If a batch with a short sequence length is
# followed by another batch with a longer sequence length, then PyTorch is
# forced to release intermediate buffers from the previous iteration and to
# re-allocate new buffers. This process is time-consuming and causes
# fragmentation in the caching allocator, which may result in out-of-memory
# errors.
#
# A typical solution is to implement preallocation. It consists of the
# following steps:
#
# #. generate a (usually random) batch of inputs with maximum sequence length
#    (either corresponding to max length in the training dataset or to some
#    predefined threshold)
# #. execute a forward and a backward pass with the generated batch, but do
#    not execute an optimizer or a learning rate scheduler. This step
#    preallocates buffers of maximum size, which can be reused in subsequent
#    training iterations
# #. zero out gradients
# #. proceed to regular training
#
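
###############################################################################
# The steps above can be sketched as follows, assuming a hypothetical
# token-level model (embedding plus projection) and hypothetical size
# constants. On a GPU, ``torch.cuda.memory_reserved()`` can be used before and
# after the warm-up to see the memory the caching allocator keeps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB_SIZE = 100  # hypothetical vocabulary size
MAX_SEQ_LEN = 64  # hypothetical maximum sequence length
BATCH_SIZE = 8

device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical token-level model: embedding followed by a projection to the vocabulary.
model = nn.Sequential(nn.Embedding(VOCAB_SIZE, 32), nn.Linear(32, VOCAB_SIZE)).to(device)


def preallocate(model, batch_size, max_seq_len):
    # 1. Random batch with the maximum sequence length.
    tokens = torch.randint(0, VOCAB_SIZE, (batch_size, max_seq_len), device=device)
    targets = torch.randint(0, VOCAB_SIZE, (batch_size, max_seq_len), device=device)
    # 2. Forward and backward pass only: no optimizer or scheduler step.
    logits = model(tokens)
    loss = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))
    loss.backward()
    # 3. Zero out gradients so the dummy batch does not affect training.
    model.zero_grad(set_to_none=True)


preallocate(model, BATCH_SIZE, MAX_SEQ_LEN)
# 4. Proceed to the regular training loop.
```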

###############################################################################
# Distributed optimizations
# -------------------------

###############################################################################
# Use efficient data-parallel backend
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# PyTorch has two ways to implement data-parallel training:
#
# * `torch.nn.DataParallel <https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel>`_
# * `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
#
# ``DistributedDataParallel`` offers much better performance and scaling to
# multiple GPUs. For more information, refer to the
# `relevant section of CUDA Best Practices <https://pytorch.org/docs/stable/notes/cuda.html#use-nn-parallel-distributeddataparallel-instead-of-multiprocessing-or-nn-dataparallel>`_
# from the PyTorch documentation.

###############################################################################
# Skip unnecessary all-reduce if training with ``DistributedDataParallel`` and gradient accumulation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default,
# `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# executes gradient all-reduce after every backward pass to compute the average
# gradient over all workers participating in the training. If training uses
# gradient accumulation over N steps, then all-reduce is not necessary after
# every training step. It only needs to run after the last call to backward,
# just before the optimizer executes.
#
# ``DistributedDataParallel`` provides the
# `no_sync() <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync>`_
# context manager, which disables gradient all-reduce for a particular
# iteration. ``no_sync()`` should be applied to the first ``N-1`` iterations of
# gradient accumulation. The last iteration should follow the default
# execution and perform the required gradient all-reduce.
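
###############################################################################
# A sketch of one accumulation cycle using a hypothetical helper. With a
# DDP-wrapped model, the forward and backward passes of the first ``N-1``
# micro-batches run inside ``no_sync()``. The last one runs normally, which
# triggers the all-reduce. The helper also accepts an unwrapped module, where
# it simply accumulates, so the logic can be tried without a process group.

```python
import contextlib

import torch
import torch.nn as nn


def accumulate_and_step(model, optimizer, loss_fn, micro_batches):
    """One optimizer step over several micro-batches (gradient accumulation)."""
    num_micro = len(micro_batches)
    optimizer.zero_grad(set_to_none=True)
    for i, (inputs, targets) in enumerate(micro_batches):
        is_last = i == num_micro - 1
        # DDP's no_sync() skips the all-reduce; plain modules have no such method.
        if is_last or not hasattr(model, "no_sync"):
            sync_context = contextlib.nullcontext()
        else:
            sync_context = model.no_sync()
        with sync_context:
            # Divide so the accumulated gradient is the mean over micro-batches.
            loss = loss_fn(model(inputs), targets) / num_micro
            loss.backward()
    optimizer.step()
```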

###############################################################################
# Match the order of layers in constructors and during the execution if using ``DistributedDataParallel(find_unused_parameters=True)``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# with ``find_unused_parameters=True`` uses the order of layers and parameters
# from model constructors to build buckets for ``DistributedDataParallel``
# gradient all-reduce. ``DistributedDataParallel`` overlaps all-reduce with the
# backward pass. All-reduce for a particular bucket is asynchronously triggered
# only when all gradients for parameters in a given bucket are available.
#
# To maximize the amount of overlap, the order in model constructors should
# roughly match the order during execution. If the order doesn't match, then
# all-reduce for the entire bucket waits for the gradient which is the last to
# arrive. This may reduce the overlap between the backward pass and
# all-reduce, and all-reduce may end up being exposed, which slows down
# training.
#
# ``DistributedDataParallel`` with ``find_unused_parameters=False`` (which is
# the default setting) relies on automatic bucket formation based on the order
# of operations encountered during the backward pass. With
# ``find_unused_parameters=False``, it's not necessary to reorder layers or
# parameters to achieve optimal performance.
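
###############################################################################
# A rough way to check a model before enabling
# ``find_unused_parameters=True`` is sketched below with a hypothetical helper.
# It compares the constructor (registration) order of leaf modules that own
# parameters with the order in which forward hooks fire. It approximates
# parameter order by module order, so treat it as a heuristic rather than an
# exact model of DDP's bucketing.

```python
import torch
import torch.nn as nn


def layer_order_matches(model, example_input):
    """Check whether parameterized leaf modules run in the order they are registered."""
    modules = dict(model.named_modules())
    # Leaf modules that directly own parameters, in registration (constructor) order.
    registered = [
        name for name, module in modules.items()
        if not list(module.children()) and list(module.parameters(recurse=False))
    ]
    executed = []
    handles = [
        modules[name].register_forward_hook(
            lambda module, inputs, output, name=name: executed.append(name))
        for name in registered
    ]
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()
    return executed == registered


class Ordered(nn.Module):  # hypothetical: constructor order matches execution
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 2)

    def forward(self, x):
        return self.decoder(torch.relu(self.encoder(x)))


class Reordered(nn.Module):  # hypothetical: decoder registered first but runs last
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 2)
        self.encoder = nn.Linear(8, 8)

    def forward(self, x):
        return self.decoder(torch.relu(self.encoder(x)))
```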

###############################################################################
# Load-balance workload in a distributed setting
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load imbalance typically occurs in models processing sequential data
# (speech recognition, translation, language models, etc.). If one device
# receives a batch of data with a longer sequence length than the sequence
# lengths on the remaining devices, then all devices wait for the worker which
# finishes last. The backward pass functions as an implicit synchronization
# point in a distributed setting with the
# `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_
# backend.
#
# There are multiple ways to solve the load-balancing problem. The core idea is
# to distribute the workload over all workers as uniformly as possible within
# each global batch. For example, Transformer solves imbalance by forming
# batches with an approximately constant number of tokens (and a variable
# number of sequences per batch). Other models solve imbalance by bucketing
# samples with similar sequence length or even by sorting the dataset by
# sequence length.
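
###############################################################################
# A sketch of the constant-token-count idea, assuming each batch is padded to
# its longest sequence. The samples are sorted by length and grouped so that
# ``longest_length * num_sequences`` stays within a token budget. The function
# name and the budget value are hypothetical. In practice, the resulting
# batches would then be shuffled and spread across workers.

```python
def token_budget_batches(lengths, max_tokens):
    """Group sample indices into batches whose padded size stays within max_tokens.

    Padded size = (longest sequence in the batch) * (number of sequences).
    A single sequence longer than max_tokens still gets its own batch.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # similar lengths together
    batches, current, current_max = [], [], 0
    for i in order:
        new_max = max(current_max, lengths[i])
        if current and new_max * (len(current) + 1) > max_tokens:
            batches.append(current)  # budget exceeded: close the current batch
            current, new_max = [], lengths[i]
        current.append(i)
        current_max = new_max
    if current:
        batches.append(current)
    return batches


lengths = [5, 3, 8, 2, 7, 4]  # hypothetical sequence lengths
batches = token_budget_batches(lengths, max_tokens=12)
```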

###############################################################################
# Conclusion
# ----------
#
# This tutorial covered a comprehensive set of performance optimization techniques
# for PyTorch models. The key takeaways include:
#
# * **General optimizations**: Enable async data loading, disable gradients for
#   inference, fuse operations with ``torch.compile``, and use efficient memory formats
# * **CPU optimizations**: Leverage NUMA controls, optimize OpenMP settings, and
#   use efficient memory allocators
# * **GPU optimizations**: Enable Tensor cores, use CUDA graphs, enable cuDNN
#   autotuner, and implement mixed precision training
# * **Distributed optimizations**: Use DistributedDataParallel, optimize gradient
#   synchronization, and balance workloads across devices
#
# Many of these optimizations can be applied with minimal code changes and provide
# significant performance improvements across a wide range of deep learning models.
#
# Further Reading
# ---------------
#
# * `PyTorch Performance Tuning Documentation <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_
# * `CUDA Best Practices <https://pytorch.org/docs/stable/notes/cuda.html>`_
# * `Distributed Training Documentation <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
# * `Mixed Precision Training <https://pytorch.org/docs/stable/amp.html>`_
# * `torch.compile Tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_