Path: blob/main/unstable_source/gpu_direct_storage.py
"""1(prototype) Accelerating ``torch.save`` and ``torch.load`` with GPUDirect Storage2=================================================================================34GPUDirect Storage enables a direct data path for direct memory access transfers5between GPU memory and storage, avoiding a bounce buffer through the CPU.67In version **2.7**, we introduced new prototype APIs to ``torch.cuda.gds`` that serve as thin wrappers around8the `cuFile APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_9that can be used with ``torch.Tensor`` to achieve improved I/O performance.1011In this tutorial, we will demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with12checkpoints generated by ``torch.save`` and ``torch.load`` on local filesystem.1314.. grid:: 21516.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn17:class-card: card-prerequisites1819* Understand how to use the ``torch.cuda.gds`` APIs in conjunction with20checkpoints generated by ``torch.save`` and ``torch.load`` on local filesystem2122.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites23:class-card: card-prerequisites2425* PyTorch v.2.7.0 or later26* GPUDirect Storage must be installed per27`the documentation <https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html>`_28* Ensure that the filesystem that you are saving/loading to supports GPUDirect Storage.29"""3031################################################################################32# Using GPUDirect Storage with ``torch.save`` and ``torch.load``33# ------------------------------------------------------------------------------------34# GPUDirect Storage requires a storage alignment of 4KB. 
# You can set this via
# ``torch.utils.serialization.config.save.storage_alignment``:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096

################################################################################
# The steps involved in the process are as follows:
#
# * Write the checkpoint file without any actual data. This reserves the space on disk.
# * Read the offsets for the storage associated with each tensor in the checkpoint using ``FakeTensor``.
# * Use ``GdsFile`` to write the appropriate data at these offsets.
#
# Given a state dictionary of tensors that are on the GPU, one can use the ``torch.serialization.skip_data`` context
# manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each ``torch.Storage``
# in the state dictionary, space will be reserved within the checkpoint for the storage bytes.

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")

################################################################################
# We can get the offsets that each storage should be written to within the checkpoint by loading under
# a ``FakeTensorMode``. A ``FakeTensor`` is a tensor that carries metadata (such as sizes, strides, dtype
# and device) about the tensor but does not have any storage bytes. The following snippet will not materialize
# any data but will tag each ``FakeTensor`` with the offset within the checkpoint that
# corresponds to the tensor.
#
# If you are continuously saving the same state dictionary during training, you
# would only need to obtain the offsets once, and the same offsets can be re-used.
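#
# For checkpoints that are written many times, the storages can additionally be
# registered as GDS buffers up front. The following is a minimal sketch (it assumes
# a CUDA device with GPUDirect Storage installed and the ``sd`` state dictionary
# from above); ``torch.cuda.gds.gds_register_buffer`` wraps ``cuFileBufRegister``,
# and registration only pays off when the same storages are used for repeated I/O.

for v in sd.values():
    # Register each GPU storage with cuFile so that subsequent GDS reads and
    # writes can go through the registered buffer directly.
    torch.cuda.gds.gds_register_buffer(v.untyped_storage())

################################################################################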
# Similarly, if a tensor is going to be saved or loaded repeatedly, you can use
# ``torch.cuda.gds.gds_register_buffer``, which wraps ``cuFileBufRegister``, to
# register the storages as GDS buffers.
#
# Note that ``torch.cuda.gds.GdsFile.save_storage`` binds to the synchronous ``cuFileWrite`` API,
# so no synchronization is needed afterwards.


import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # save_storage is a wrapper around `cuFileWrite`
    f.save_storage(v.untyped_storage(), offset)


################################################################################
# We verify the correctness of the saved checkpoint by loading it with ``torch.load`` and comparing.

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

################################################################################
# The loading flow is the inverse: you can use ``torch.load`` with the ``torch.serialization.skip_data`` context
# manager to load everything except the storage bytes.
# This means that any tensors in the checkpoint will be
# created, but their storages will be empty (as if the tensors had been created via ``torch.empty``).

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

################################################################################
# We once again use the ``FakeTensorMode`` to get the checkpoint offsets and
# ascertain that the loaded checkpoint is the same as the saved checkpoint.
#
# Similar to ``torch.cuda.gds.GdsFile.save_storage``, ``torch.cuda.gds.GdsFile.load_storage``
# binds to the synchronous ``cuFileRead`` API, so no synchronization is needed afterwards.

for k, v in sd_loaded.items():
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # load_storage is a wrapper around `cuFileRead`
    f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

del f

################################################################################
# Conclusion
# ==========
#
# In this tutorial we have demonstrated how to use the prototype ``torch.cuda.gds`` APIs
# in conjunction with ``torch.save`` and ``torch.load`` on a local filesystem. Please
# file an issue in the PyTorch GitHub repo if you have any feedback.