Path: blob/main/beginner_source/saving_loading_models.py
1686 views
# -*- coding: utf-8 -*-1"""2Saving and Loading Models3=========================4**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_56This document provides solutions to a variety of use cases regarding the7saving and loading of PyTorch models. Feel free to read the whole8document, or just skip to the code you need for a desired use case.910When it comes to saving and loading models, there are three core11functions to be familiar with:12131) `torch.save <https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save>`__:14Saves a serialized object to disk. This function uses Python’s15`pickle <https://docs.python.org/3/library/pickle.html>`__ utility16for serialization. Models, tensors, and dictionaries of all kinds of17objects can be saved using this function.18192) `torch.load <https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load>`__:20Uses `pickle <https://docs.python.org/3/library/pickle.html>`__\ ’s21unpickling facilities to deserialize pickled object files to memory.22This function also facilitates the device to load the data into (see23`Saving & Loading Model Across24Devices <#saving-loading-model-across-devices>`__).25263) `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict>`__:27Loads a model’s parameter dictionary using a deserialized28*state_dict*. For more information on *state_dict*, see `What is a29state_dict? <#what-is-a-state-dict>`__.30313233**Contents:**3435- `What is a state_dict? <#what-is-a-state-dict>`__36- `Saving & Loading Model for37Inference <#saving-loading-model-for-inference>`__38- `Saving & Loading a General39Checkpoint <#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training>`__40- `Saving Multiple Models in One41File <#saving-multiple-models-in-one-file>`__42- `Warmstarting Model Using Parameters from a Different43Model <#warmstarting-model-using-parameters-from-a-different-model>`__44- `Saving & Loading Model Across45Devices <#saving-loading-model-across-devices>`__4647"""484950######################################################################51# What is a ``state_dict``?52# -------------------------53#54# In PyTorch, the learnable parameters (i.e. weights and biases) of an55# ``torch.nn.Module`` model are contained in the model’s *parameters*56# (accessed with ``model.parameters()``). A *state_dict* is simply a57# Python dictionary object that maps each layer to its parameter tensor.58# Note that only layers with learnable parameters (convolutional layers,59# linear layers, etc.) and registered buffers (batchnorm's running_mean)60# have entries in the model’s *state_dict*. Optimizer61# objects (``torch.optim``) also have a *state_dict*, which contains62# information about the optimizer's state, as well as the hyperparameters63# used.64#65# Because *state_dict* objects are Python dictionaries, they can be easily66# saved, updated, altered, and restored, adding a great deal of modularity67# to PyTorch models and optimizers.68#69# Example:70# ^^^^^^^^71#72# Let’s take a look at the *state_dict* from the simple model used in the73# `Training a74# classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py>`__75# tutorial.76#77# .. code:: python78#79# # Define model80# class TheModelClass(nn.Module):81# def __init__(self):82# super(TheModelClass, self).__init__()83# self.conv1 = nn.Conv2d(3, 6, 5)84# self.pool = nn.MaxPool2d(2, 2)85# self.conv2 = nn.Conv2d(6, 16, 5)86# self.fc1 = nn.Linear(16 * 5 * 5, 120)87# self.fc2 = nn.Linear(120, 84)88# self.fc3 = nn.Linear(84, 10)89#90# def forward(self, x):91# x = self.pool(F.relu(self.conv1(x)))92# x = self.pool(F.relu(self.conv2(x)))93# x = x.view(-1, 16 * 5 * 5)94# x = F.relu(self.fc1(x))95# x = F.relu(self.fc2(x))96# x = self.fc3(x)97# return x98#99# # Initialize model100# model = TheModelClass()101#102# # Initialize optimizer103# optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)104#105# # Print model's state_dict106# print("Model's state_dict:")107# for param_tensor in model.state_dict():108# print(param_tensor, "\t", model.state_dict()[param_tensor].size())109#110# # Print optimizer's state_dict111# print("Optimizer's state_dict:")112# for var_name in optimizer.state_dict():113# print(var_name, "\t", optimizer.state_dict()[var_name])114#115# **Output:**116#117# .. code-block:: sh118#119# Model's state_dict:120# conv1.weight torch.Size([6, 3, 5, 5])121# conv1.bias torch.Size([6])122# conv2.weight torch.Size([16, 6, 5, 5])123# conv2.bias torch.Size([16])124# fc1.weight torch.Size([120, 400])125# fc1.bias torch.Size([120])126# fc2.weight torch.Size([84, 120])127# fc2.bias torch.Size([84])128# fc3.weight torch.Size([10, 84])129# fc3.bias torch.Size([10])130#131# Optimizer's state_dict:132# state {}133# param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]134#135136137######################################################################138# Saving & Loading Model for Inference139# ------------------------------------140#141# Save/Load ``state_dict`` (Recommended)142# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^143#144# **Save:**145#146# .. code:: python147#148# torch.save(model.state_dict(), PATH)149#150# **Load:**151#152# .. code:: python153#154# model = TheModelClass(*args, **kwargs)155# model.load_state_dict(torch.load(PATH, weights_only=True))156# model.eval()157#158# .. note::159# The 1.6 release of PyTorch switched ``torch.save`` to use a new160# zip file-based format. ``torch.load`` still retains the ability to161# load files in the old format. If for any reason you want ``torch.save``162# to use the old format, pass the ``kwarg`` parameter ``_use_new_zipfile_serialization=False``.163#164# When saving a model for inference, it is only necessary to save the165# trained model’s learned parameters. Saving the model’s *state_dict* with166# the ``torch.save()`` function will give you the most flexibility for167# restoring the model later, which is why it is the recommended method for168# saving models.169#170# A common PyTorch convention is to save models using either a ``.pt`` or171# ``.pth`` file extension.172#173# Remember that you must call ``model.eval()`` to set dropout and batch174# normalization layers to evaluation mode before running inference.175# Failing to do this will yield inconsistent inference results.176#177# .. note::178#179# Notice that the ``load_state_dict()`` function takes a dictionary180# object, NOT a path to a saved object. This means that you must181# deserialize the saved *state_dict* before you pass it to the182# ``load_state_dict()`` function. For example, you CANNOT load using183# ``model.load_state_dict(PATH)``.184#185# .. note::186#187# If you only plan to keep the best performing model (according to the188# acquired validation loss), don't forget that ``best_model_state = model.state_dict()``189# returns a reference to the state and not its copy! You must serialize190# ``best_model_state`` or use ``best_model_state = deepcopy(model.state_dict())`` otherwise191# your best ``best_model_state`` will keep getting updated by the subsequent training192# iterations. As a result, the final model state will be the state of the overfitted model.193#194# Save/Load Entire Model195# ^^^^^^^^^^^^^^^^^^^^^^196#197# **Save:**198#199# .. code:: python200#201# torch.save(model, PATH)202#203# **Load:**204#205# .. code:: python206#207# # Model class must be defined somewhere208# model = torch.load(PATH, weights_only=False)209# model.eval()210#211# This save/load process uses the most intuitive syntax and involves the212# least amount of code. Saving a model in this way will save the entire213# module using Python’s214# `pickle <https://docs.python.org/3/library/pickle.html>`__ module. The215# disadvantage of this approach is that the serialized data is bound to216# the specific classes and the exact directory structure used when the217# model is saved. The reason for this is because pickle does not save the218# model class itself. Rather, it saves a path to the file containing the219# class, which is used during load time. Because of this, your code can220# break in various ways when used in other projects or after refactors.221#222# A common PyTorch convention is to save models using either a ``.pt`` or223# ``.pth`` file extension.224#225# Remember that you must call ``model.eval()`` to set dropout and batch226# normalization layers to evaluation mode before running inference.227# Failing to do this will yield inconsistent inference results.228#229# Saving an Exported Program230# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~231#232# If you are using ``torch.export``, you can save and load your ``ExportedProgram`` using the233# ``torch.export.save()`` and ``torch.export.load()`` APIs. with the ``.pt2`` file extension:234#235# .. code-block:: python236#237# class SimpleModel(torch.nn.Module):238# def forward(self, x):239# return x + 10240#241# # Create a sample input242# sample_input = torch.randn(5)243#244# # Export the model245# exported_program = torch.export.export(SimpleModel(), sample_input)246#247# # Save the exported program248# torch.export.save(exported_program, 'exported_program.pt2')249#250# # Load the exported program251# saved_exported_program = torch.export.load('exported_program.pt2')252#253254######################################################################255# Saving & Loading a General Checkpoint for Inference and/or Resuming Training256# ----------------------------------------------------------------------------257#258# Save:259# ^^^^^260#261# .. code:: python262#263# torch.save({264# 'epoch': epoch,265# 'model_state_dict': model.state_dict(),266# 'optimizer_state_dict': optimizer.state_dict(),267# 'loss': loss,268# ...269# }, PATH)270#271# Load:272# ^^^^^273#274# .. code:: python275#276# model = TheModelClass(*args, **kwargs)277# optimizer = TheOptimizerClass(*args, **kwargs)278#279# checkpoint = torch.load(PATH, weights_only=True)280# model.load_state_dict(checkpoint['model_state_dict'])281# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])282# epoch = checkpoint['epoch']283# loss = checkpoint['loss']284#285# model.eval()286# # - or -287# model.train()288#289# When saving a general checkpoint, to be used for either inference or290# resuming training, you must save more than just the model’s291# *state_dict*. It is important to also save the optimizer's *state_dict*,292# as this contains buffers and parameters that are updated as the model293# trains. Other items that you may want to save are the epoch you left off294# on, the latest recorded training loss, external ``torch.nn.Embedding``295# layers, etc. As a result, such a checkpoint is often 2~3 times larger296# than the model alone.297#298# To save multiple components, organize them in a dictionary and use299# ``torch.save()`` to serialize the dictionary. A common PyTorch300# convention is to save these checkpoints using the ``.tar`` file301# extension.302#303# To load the items, first initialize the model and optimizer, then load304# the dictionary locally using ``torch.load()``. From here, you can easily305# access the saved items by simply querying the dictionary as you would306# expect.307#308# Remember that you must call ``model.eval()`` to set dropout and batch309# normalization layers to evaluation mode before running inference.310# Failing to do this will yield inconsistent inference results. If you311# wish to resuming training, call ``model.train()`` to ensure these layers312# are in training mode.313#314315316######################################################################317# Saving Multiple Models in One File318# ----------------------------------319#320# Save:321# ^^^^^322#323# .. code:: python324#325# torch.save({326# 'modelA_state_dict': modelA.state_dict(),327# 'modelB_state_dict': modelB.state_dict(),328# 'optimizerA_state_dict': optimizerA.state_dict(),329# 'optimizerB_state_dict': optimizerB.state_dict(),330# ...331# }, PATH)332#333# Load:334# ^^^^^335#336# .. code:: python337#338# modelA = TheModelAClass(*args, **kwargs)339# modelB = TheModelBClass(*args, **kwargs)340# optimizerA = TheOptimizerAClass(*args, **kwargs)341# optimizerB = TheOptimizerBClass(*args, **kwargs)342#343# checkpoint = torch.load(PATH, weights_only=True)344# modelA.load_state_dict(checkpoint['modelA_state_dict'])345# modelB.load_state_dict(checkpoint['modelB_state_dict'])346# optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])347# optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])348#349# modelA.eval()350# modelB.eval()351# # - or -352# modelA.train()353# modelB.train()354#355# When saving a model comprised of multiple ``torch.nn.Modules``, such as356# a GAN, a sequence-to-sequence model, or an ensemble of models, you357# follow the same approach as when you are saving a general checkpoint. In358# other words, save a dictionary of each model’s *state_dict* and359# corresponding optimizer. As mentioned before, you can save any other360# items that may aid you in resuming training by simply appending them to361# the dictionary.362#363# A common PyTorch convention is to save these checkpoints using the364# ``.tar`` file extension.365#366# To load the models, first initialize the models and optimizers, then367# load the dictionary locally using ``torch.load()``. From here, you can368# easily access the saved items by simply querying the dictionary as you369# would expect.370#371# Remember that you must call ``model.eval()`` to set dropout and batch372# normalization layers to evaluation mode before running inference.373# Failing to do this will yield inconsistent inference results. If you374# wish to resuming training, call ``model.train()`` to set these layers to375# training mode.376#377378379######################################################################380# Warmstarting Model Using Parameters from a Different Model381# ----------------------------------------------------------382#383# Save:384# ^^^^^385#386# .. code:: python387#388# torch.save(modelA.state_dict(), PATH)389#390# Load:391# ^^^^^392#393# .. code:: python394#395# modelB = TheModelBClass(*args, **kwargs)396# modelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)397#398# Partially loading a model or loading a partial model are common399# scenarios when transfer learning or training a new complex model.400# Leveraging trained parameters, even if only a few are usable, will help401# to warmstart the training process and hopefully help your model converge402# much faster than training from scratch.403#404# Whether you are loading from a partial *state_dict*, which is missing405# some keys, or loading a *state_dict* with more keys than the model that406# you are loading into, you can set the ``strict`` argument to **False**407# in the ``load_state_dict()`` function to ignore non-matching keys.408#409# If you want to load parameters from one layer to another, but some keys410# do not match, simply change the name of the parameter keys in the411# *state_dict* that you are loading to match the keys in the model that412# you are loading into.413#414415416######################################################################417# Saving & Loading Model Across Devices418# -------------------------------------419#420# Save on GPU, Load on CPU421# ^^^^^^^^^^^^^^^^^^^^^^^^422#423# **Save:**424#425# .. code:: python426#427# torch.save(model.state_dict(), PATH)428#429# **Load:**430#431# .. code:: python432#433# device = torch.device('cpu')434# model = TheModelClass(*args, **kwargs)435# model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))436#437# When loading a model on a CPU that was trained with a GPU, pass438# ``torch.device('cpu')`` to the ``map_location`` argument in the439# ``torch.load()`` function. In this case, the storages underlying the440# tensors are dynamically remapped to the CPU device using the441# ``map_location`` argument.442#443# Save on GPU, Load on GPU444# ^^^^^^^^^^^^^^^^^^^^^^^^445#446# **Save:**447#448# .. code:: python449#450# torch.save(model.state_dict(), PATH)451#452# **Load:**453#454# .. code:: python455#456# device = torch.device("cuda")457# model = TheModelClass(*args, **kwargs)458# model.load_state_dict(torch.load(PATH, weights_only=True))459# model.to(device)460# # Make sure to call input = input.to(device) on any input tensors that you feed to the model461#462# When loading a model on a GPU that was trained and saved on GPU, simply463# convert the initialized ``model`` to a CUDA optimized model using464# ``model.to(torch.device('cuda'))``. Also, be sure to use the465# ``.to(torch.device('cuda'))`` function on all model inputs to prepare466# the data for the model. Note that calling ``my_tensor.to(device)``467# returns a new copy of ``my_tensor`` on GPU. It does NOT overwrite468# ``my_tensor``. Therefore, remember to manually overwrite tensors:469# ``my_tensor = my_tensor.to(torch.device('cuda'))``.470#471# Save on CPU, Load on GPU472# ^^^^^^^^^^^^^^^^^^^^^^^^473#474# **Save:**475#476# .. code:: python477#478# torch.save(model.state_dict(), PATH)479#480# **Load:**481#482# .. code:: python483#484# device = torch.device("cuda")485# model = TheModelClass(*args, **kwargs)486# model.load_state_dict(torch.load(PATH, weights_only=True, map_location="cuda:0")) # Choose whatever GPU device number you want487# model.to(device)488# # Make sure to call input = input.to(device) on any input tensors that you feed to the model489#490# When loading a model on a GPU that was trained and saved on CPU, set the491# ``map_location`` argument in the ``torch.load()`` function to492# ``cuda:device_id``. This loads the model to a given GPU device. Next, be493# sure to call ``model.to(torch.device('cuda'))`` to convert the model’s494# parameter tensors to CUDA tensors. Finally, be sure to use the495# ``.to(torch.device('cuda'))`` function on all model inputs to prepare496# the data for the CUDA optimized model. Note that calling497# ``my_tensor.to(device)`` returns a new copy of ``my_tensor`` on GPU. It498# does NOT overwrite ``my_tensor``. Therefore, remember to manually499# overwrite tensors: ``my_tensor = my_tensor.to(torch.device('cuda'))``.500#501# Saving ``torch.nn.DataParallel`` Models502# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^503#504# **Save:**505#506# .. code:: python507#508# torch.save(model.module.state_dict(), PATH)509#510# **Load:**511#512# .. code:: python513#514# # Load to whatever device you want515#516# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU517# utilization. To save a ``DataParallel`` model generically, save the518# ``model.module.state_dict()``. This way, you have the flexibility to519# load the model any way you want to any device you want.520#521522523