"""
2
(prototype) Accelerating ``torch.save`` and ``torch.load`` with GPUDirect Storage
3
=================================================================================
4
5
GPUDirect Storage enables a direct data path for direct memory access transfers
6
between GPU memory and storage, avoiding a bounce buffer through the CPU.
7
8
In version **2.7**, we introduced new prototype APIs to ``torch.cuda.gds`` that serve as thin wrappers around
9
the `cuFile APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
10
that can be used with ``torch.Tensor`` to achieve improved I/O performance.
11
12
In this tutorial, we will demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with
13
checkpoints generated by ``torch.save`` and ``torch.load`` on local filesystem.
14
15
.. grid:: 2
16
17
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
18
:class-card: card-prerequisites
19
20
* Understand how to use the ``torch.cuda.gds`` APIs in conjunction with
21
checkpoints generated by ``torch.save`` and ``torch.load`` on local filesystem
22
23
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
24
:class-card: card-prerequisites
25
26
* PyTorch v.2.7.0 or later
27
* GPUDirect Storage must be installed per
28
`the documentation <https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html>`_
29
* Ensure that the filesystem that you are saving/loading to supports GPUDirect Storage.
30
"""

################################################################################
# Using GPUDirect Storage with ``torch.save`` and ``torch.load``
# ------------------------------------------------------------------------------------
# GPUDirect Storage requires a storage alignment of 4KB (4096 bytes). You can set this by using
# ``torch.utils.serialization.config.save.storage_alignment``:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096
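
################################################################################
# To make the alignment requirement concrete, the following is a minimal sketch of
# the padding arithmetic it implies: the start of each storage's bytes in the file
# is rounded up to the next multiple of 4096. ``align_up`` is a hypothetical helper
# written for illustration (it is not part of ``torch.utils.serialization``), and the
# byte positions are made-up examples rather than offsets from a real checkpoint.


def align_up(position, alignment=4096):
    """Round ``position`` up to the next multiple of ``alignment`` (a power of two)."""
    return (position + alignment - 1) & ~(alignment - 1)


# Hypothetical positions where a record's metadata might end inside the file.
for header_end in (0, 150, 4096, 4097):
    data_start = align_up(header_end)
    print(f"metadata ends at byte {header_end:>5}: data starts at byte {data_start:>5} "
          f"({data_start - header_end} bytes of padding)")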

################################################################################
# The process involves the following steps:
#
# * Write the checkpoint file without any actual data. This reserves space on disk for the storage bytes.
# * Read the offset of the storage associated with each tensor in the checkpoint using ``FakeTensor``.
# * Use ``GdsFile`` to write the appropriate data at these offsets.
#
# Given a state dictionary of tensors that are on the GPU, you can use the ``torch.serialization.skip_data`` context
# manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each ``torch.Storage``
# in the state dictionary, space is still reserved within the checkpoint for its bytes.

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

# Write metadata only; the storage bytes are left as reserved space in the file.
with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")

################################################################################
# We can get the offset that each storage should be written to within the checkpoint by loading under
# a ``FakeTensorMode``. A ``FakeTensor`` is a tensor that has metadata (such as sizes, strides, dtype, and device)
# but does not have any storage bytes. The following snippet does not materialize
# any data, but it tags the storage of each ``FakeTensor`` with the offset within the checkpoint
# that corresponds to the tensor (``untyped_storage()._checkpoint_offset``).
#
# If you are repeatedly saving the same state dictionary during training, you
# only need to obtain the offsets once, and the same offsets can be reused. Similarly, if a tensor is going to
# be saved or loaded repeatedly, you can use ``torch.cuda.gds.gds_register_buffer``, which wraps
# ``cuFileBufRegister``, to register its storage as a GDS buffer.
#
# Note that ``torch.cuda.gds.GdsFile.save_storage`` binds to the synchronous ``cuFileWrite`` API,
# so no synchronization is needed afterwards.

import os
from torch._subclasses.fake_tensor import FakeTensorMode

# Loading under FakeTensorMode reads only metadata and records each storage's offset.
with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # save_storage is a wrapper around ``cuFileWrite``
    f.save_storage(v.untyped_storage(), offset)
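
################################################################################
# The save flow above follows a reserve-then-fill pattern: lay out the file once,
# record where each storage's bytes belong, then write the bytes in place at those
# offsets. The sketch below is a CPU-only analogy that does not use GPUDirect Storage:
# it substitutes POSIX ``os.pwrite`` (Unix-like systems only) for ``GdsFile.save_storage``,
# uses a made-up 64-byte header and 4096-byte-aligned slots, and reuses the same
# offsets for a second save, as described earlier for repeated checkpointing. The
# helper names are hypothetical.

import os
import tempfile

ALIGNMENT = 4096  # assumed to match the storage_alignment configured above


def reserve_layout(sizes, header_bytes=64, alignment=ALIGNMENT):
    """Hypothetical helper: give each payload an aligned slot after a fake header.

    Returns a dict mapping each key to its offset, plus the total file size.
    """
    slots, position = {}, header_bytes
    for key, size in sizes.items():
        position = (position + alignment - 1) // alignment * alignment
        slots[key] = position
        position += size
    return slots, position


def fill_reserved(path, slots, payloads):
    """Hypothetical helper: write each payload at its reserved offset in place."""
    fd = os.open(path, os.O_RDWR)
    try:
        for key, payload in payloads.items():
            # Positional write: the CPU stand-in for cuFileWrite in this sketch.
            os.pwrite(fd, payload, slots[key])
    finally:
        os.close(fd)


# Payload sizes of a float32 nn.Linear(5, 10): weight 10 x 5 x 4 bytes, bias 10 x 4 bytes.
sizes = {"weight": 200, "bias": 40}
# The offsets are computed once and reused for every save below.
slot_offsets, total_size = reserve_layout(sizes)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "reserved.bin")
    # Reserve space: create the file at its full size with no payload data (zero bytes).
    with open(path, "wb") as fh:
        fh.truncate(total_size)
    # Fill the reserved slots twice, as in repeated checkpointing, reusing the offsets.
    for step in (1, 2):
        fill_reserved(path, slot_offsets, {k: bytes([step]) * n for k, n in sizes.items()})
    with open(path, "rb") as fh:
        contents = fh.read()

print(slot_offsets, total_size)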

################################################################################
# We verify the correctness of the saved checkpoint by loading it with ``torch.load`` and
# comparing the result against the original state dictionary.

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

################################################################################
# The loading flow is the inverse: you can use ``torch.load`` with the ``torch.serialization.skip_data`` context
# manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be
# created, but their storages will be uninitialized (as if the tensors were created via ``torch.empty``).

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

################################################################################
# We reuse the offsets obtained earlier under ``FakeTensorMode`` (stored in ``fake_sd``) to load
# the storage bytes, and then ascertain that the loaded checkpoint is the same as the saved checkpoint.
#
# Similar to ``torch.cuda.gds.GdsFile.save_storage``, ``torch.cuda.gds.GdsFile.load_storage``
# binds to the synchronous ``cuFileRead`` API, so no synchronization is needed afterwards.

for k, v in sd_loaded.items():
    # The storages are still uninitialized, so they should not match the originals yet.
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # load_storage is a wrapper around ``cuFileRead``
    f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

del f
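
################################################################################
# The load flow mirrors the save flow: ``skip_data`` allocates the destination
# storages without reading their bytes, and ``GdsFile.load_storage`` then fills each
# one in place from its offset. As a CPU-only analogy that does not use GPUDirect
# Storage, the sketch below preallocates ``bytearray`` buffers and fills them in place
# with ``os.preadv`` (available on Linux and some other Unix-like systems). The layout
# and byte values are hypothetical.

import os
import tempfile

# Hypothetical layout recorded at save time: key -> (offset, number of bytes).
layout = {"weight": (4096, 200), "bias": (8192, 40)}
fill_byte = {"weight": 0x11, "bias": 0x22}  # made-up contents for the demo

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "filled.bin")
    fd = os.open(path, os.O_RDWR | os.O_CREAT)
    try:
        # Set up a file whose payloads sit at the offsets listed in ``layout``.
        for key, (start, nbytes) in layout.items():
            os.pwrite(fd, bytes([fill_byte[key]]) * nbytes, start)

        # Analog of loading under ``skip_data``: allocate destinations, read nothing.
        buffers = {key: bytearray(nbytes) for key, (_, nbytes) in layout.items()}

        # Analog of ``load_storage``: fill each preallocated buffer in place from its offset.
        for key, (start, nbytes) in layout.items():
            n_read = os.preadv(fd, [buffers[key]], start)
            assert n_read == nbytes, f"short read for {key!r}"
    finally:
        os.close(fd)

for key, buf in buffers.items():
    print(key, len(buf), hex(buf[0]))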
##########################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated how to use the prototype ``torch.cuda.gds`` APIs
# in conjunction with ``torch.save`` and ``torch.load`` on a local filesystem. Please
# file an issue in the PyTorch GitHub repo if you have any feedback.