# -*- coding: utf-8 -*-
"""
Saving and Loading Models
=========================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_

This document provides solutions to a variety of use cases regarding the
saving and loading of PyTorch models. Feel free to read the whole
document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core
functions to be familiar with:

1) `torch.save <https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save>`__:
   Saves a serialized object to disk. This function uses Python’s
   `pickle <https://docs.python.org/3/library/pickle.html>`__ utility
   for serialization. Models, tensors, and dictionaries of all kinds of
   objects can be saved using this function.

2) `torch.load <https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load>`__:
   Uses `pickle <https://docs.python.org/3/library/pickle.html>`__\ ’s
   unpickling facilities to deserialize pickled object files to memory.
   This function also lets you specify the device to load the data onto (see
   `Saving & Loading Model Across
   Devices <#saving-loading-model-across-devices>`__).

3) `torch.nn.Module.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict>`__:
   Loads a model’s parameter dictionary using a deserialized
   *state_dict*. For more information on *state_dict*, see `What is a
   state_dict? <#what-is-a-state-dict>`__.
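
For orientation, here is a minimal sketch tying the three functions
together (``TheModelClass`` stands in for your own model class, and the
file name is illustrative):

.. code:: python

    import torch

    model = TheModelClass()
    torch.save(model.state_dict(), "model.pt")               # 1) serialize to disk
    state_dict = torch.load("model.pt", weights_only=True)   # 2) deserialize
    model.load_state_dict(state_dict)                        # 3) restore parameters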


**Contents:**

- `What is a state_dict? <#what-is-a-state-dict>`__
- `Saving & Loading Model for
  Inference <#saving-loading-model-for-inference>`__
- `Saving & Loading a General
  Checkpoint <#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training>`__
- `Saving Multiple Models in One
  File <#saving-multiple-models-in-one-file>`__
- `Warmstarting Model Using Parameters from a Different
  Model <#warmstarting-model-using-parameters-from-a-different-model>`__
- `Saving & Loading Model Across
  Devices <#saving-loading-model-across-devices>`__

"""


######################################################################
# What is a ``state_dict``?
# -------------------------
#
# In PyTorch, the learnable parameters (i.e. weights and biases) of a
# ``torch.nn.Module`` model are contained in the model’s *parameters*
# (accessed with ``model.parameters()``). A *state_dict* is simply a
# Python dictionary object that maps each layer to its parameter tensor.
# Note that only layers with learnable parameters (convolutional layers,
# linear layers, etc.) and registered buffers (e.g. batchnorm's ``running_mean``)
# have entries in the model’s *state_dict*. Optimizer
# objects (``torch.optim``) also have a *state_dict*, which contains
# information about the optimizer's state, as well as the hyperparameters
# used.
#
# Because *state_dict* objects are Python dictionaries, they can be easily
# saved, updated, altered, and restored, adding a great deal of modularity
# to PyTorch models and optimizers.
#
# Example:
# ^^^^^^^^
#
# Let’s take a look at the *state_dict* from the simple model used in the
# `Training a
# classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py>`__
# tutorial.
#
# .. code:: python
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.functional as F
#     import torch.optim as optim
#
#     # Define model
#     class TheModelClass(nn.Module):
#         def __init__(self):
#             super(TheModelClass, self).__init__()
#             self.conv1 = nn.Conv2d(3, 6, 5)
#             self.pool = nn.MaxPool2d(2, 2)
#             self.conv2 = nn.Conv2d(6, 16, 5)
#             self.fc1 = nn.Linear(16 * 5 * 5, 120)
#             self.fc2 = nn.Linear(120, 84)
#             self.fc3 = nn.Linear(84, 10)
#
#         def forward(self, x):
#             x = self.pool(F.relu(self.conv1(x)))
#             x = self.pool(F.relu(self.conv2(x)))
#             x = x.view(-1, 16 * 5 * 5)
#             x = F.relu(self.fc1(x))
#             x = F.relu(self.fc2(x))
#             x = self.fc3(x)
#             return x
#
#     # Initialize model
#     model = TheModelClass()
#
#     # Initialize optimizer
#     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#
#     # Print model's state_dict
#     print("Model's state_dict:")
#     for param_tensor in model.state_dict():
#         print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#
#     # Print optimizer's state_dict
#     print("Optimizer's state_dict:")
#     for var_name in optimizer.state_dict():
#         print(var_name, "\t", optimizer.state_dict()[var_name])
#
# **Output:**
#
# .. code-block:: sh
#
#     Model's state_dict:
#     conv1.weight     torch.Size([6, 3, 5, 5])
#     conv1.bias       torch.Size([6])
#     conv2.weight     torch.Size([16, 6, 5, 5])
#     conv2.bias       torch.Size([16])
#     fc1.weight       torch.Size([120, 400])
#     fc1.bias         torch.Size([120])
#     fc2.weight       torch.Size([84, 120])
#     fc2.bias         torch.Size([84])
#     fc3.weight       torch.Size([10, 84])
#     fc3.bias         torch.Size([10])
#
#     Optimizer's state_dict:
#     state    {}
#     param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
#


######################################################################
# Saving & Loading Model for Inference
# ------------------------------------
#
# Save/Load ``state_dict`` (Recommended)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#     torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#     model = TheModelClass(*args, **kwargs)
#     model.load_state_dict(torch.load(PATH, weights_only=True))
#     model.eval()
#
# .. note::
#     The 1.6 release of PyTorch switched ``torch.save`` to use a new
#     zip file-based format. ``torch.load`` still retains the ability to
#     load files in the old format. If for any reason you want ``torch.save``
#     to use the old format, pass the keyword argument
#     ``_use_new_zipfile_serialization=False``.
#
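#     For example, a minimal sketch of opting into the legacy format
#     (``PATH`` as above):
#
#     .. code:: python
#
#         torch.save(model.state_dict(), PATH, _use_new_zipfile_serialization=False)
#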
# When saving a model for inference, it is only necessary to save the
# trained model’s learned parameters. Saving the model’s *state_dict* with
# the ``torch.save()`` function will give you the most flexibility for
# restoring the model later, which is why it is the recommended method for
# saving models.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# .. note::
#
#     Notice that the ``load_state_dict()`` function takes a dictionary
#     object, NOT a path to a saved object. This means that you must
#     deserialize the saved *state_dict* before you pass it to the
#     ``load_state_dict()`` function. For example, you CANNOT load using
#     ``model.load_state_dict(PATH)``.
#
# .. note::
#
#     If you only plan to keep the best performing model (according to the
#     validation loss), don't forget that ``best_model_state = model.state_dict()``
#     returns a reference to the state, not a copy of it! You must either serialize
#     ``best_model_state`` immediately or take ``best_model_state = deepcopy(model.state_dict())``;
#     otherwise your ``best_model_state`` will keep getting updated by subsequent
#     training iterations, and the final "best" state will be the state of the
#     overfitted model.
#
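#     A minimal sketch of this pattern (``train_one_epoch`` and
#     ``evaluate`` are hypothetical helpers):
#
#     .. code:: python
#
#         from copy import deepcopy
#
#         best_loss = float('inf')
#         for epoch in range(num_epochs):
#             train_one_epoch(model)       # hypothetical training step
#             val_loss = evaluate(model)   # hypothetical validation step
#             if val_loss < best_loss:
#                 best_loss = val_loss
#                 # Copy, don't alias: state_dict() returns references
#                 best_model_state = deepcopy(model.state_dict())
#
#         model.load_state_dict(best_model_state)
#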
# Save/Load Entire Model
# ^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#     torch.save(model, PATH)
#
# **Load:**
#
# .. code:: python
#
#     # Model class must be defined somewhere
#     model = torch.load(PATH, weights_only=False)
#     model.eval()
#
# This save/load process uses the most intuitive syntax and involves the
# least amount of code. Saving a model in this way will save the entire
# module using Python’s
# `pickle <https://docs.python.org/3/library/pickle.html>`__ module. The
# disadvantage of this approach is that the serialized data is bound to
# the specific classes and the exact directory structure used when the
# model is saved. This is because pickle does not save the
# model class itself. Rather, it saves a path to the file containing the
# class, which is used during load time. Because of this, your code can
# break in various ways when used in other projects or after refactors.
#
# A common PyTorch convention is to save models using either a ``.pt`` or
# ``.pth`` file extension.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
#
# Saving an Exported Program
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# If you are using ``torch.export``, you can save and load your ``ExportedProgram`` using the
# ``torch.export.save()`` and ``torch.export.load()`` APIs, with the ``.pt2`` file extension:
#
# .. code-block:: python
#
#     class SimpleModel(torch.nn.Module):
#         def forward(self, x):
#             return x + 10
#
#     # Create a sample input
#     sample_input = torch.randn(5)
#
#     # Export the model (example inputs are passed as a tuple)
#     exported_program = torch.export.export(SimpleModel(), (sample_input,))
#
#     # Save the exported program
#     torch.export.save(exported_program, 'exported_program.pt2')
#
#     # Load the exported program
#     saved_exported_program = torch.export.load('exported_program.pt2')
#

######################################################################
# Saving & Loading a General Checkpoint for Inference and/or Resuming Training
# ----------------------------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#     torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'loss': loss,
#                 ...
#                 }, PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#     model = TheModelClass(*args, **kwargs)
#     optimizer = TheOptimizerClass(*args, **kwargs)
#
#     checkpoint = torch.load(PATH, weights_only=True)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     epoch = checkpoint['epoch']
#     loss = checkpoint['loss']
#
#     model.eval()
#     # - or -
#     model.train()
#
# When saving a general checkpoint, to be used for either inference or
# resuming training, you must save more than just the model’s
# *state_dict*. It is important to also save the optimizer's *state_dict*,
# as this contains buffers and parameters that are updated as the model
# trains. Other items that you may want to save are the epoch you left off
# on, the latest recorded training loss, external ``torch.nn.Embedding``
# layers, etc. As a result, such a checkpoint is often two to three times
# larger than the model alone.
#
# To save multiple components, organize them in a dictionary and use
# ``torch.save()`` to serialize the dictionary. A common PyTorch
# convention is to save these checkpoints using the ``.tar`` file
# extension.
#
# To load the items, first initialize the model and optimizer, then load
# the dictionary locally using ``torch.load()``. From here, you can easily
# access the saved items by simply querying the dictionary as you would
# expect.
#
# Remember that you must call ``model.eval()`` to set dropout and batch
# normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results. If you
# wish to resume training, call ``model.train()`` to ensure these layers
# are in training mode.
#
316
317
######################################################################
318
# Saving Multiple Models in One File
319
# ----------------------------------
320
#
321
# Save:
322
# ^^^^^
323
#
324
# .. code:: python
325
#
326
# torch.save({
327
# 'modelA_state_dict': modelA.state_dict(),
328
# 'modelB_state_dict': modelB.state_dict(),
329
# 'optimizerA_state_dict': optimizerA.state_dict(),
330
# 'optimizerB_state_dict': optimizerB.state_dict(),
331
# ...
332
# }, PATH)
333
#
334
# Load:
335
# ^^^^^
336
#
337
# .. code:: python
338
#
339
# modelA = TheModelAClass(*args, **kwargs)
340
# modelB = TheModelBClass(*args, **kwargs)
341
# optimizerA = TheOptimizerAClass(*args, **kwargs)
342
# optimizerB = TheOptimizerBClass(*args, **kwargs)
343
#
344
# checkpoint = torch.load(PATH, weights_only=True)
345
# modelA.load_state_dict(checkpoint['modelA_state_dict'])
346
# modelB.load_state_dict(checkpoint['modelB_state_dict'])
347
# optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
348
# optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
349
#
350
# modelA.eval()
351
# modelB.eval()
352
# # - or -
353
# modelA.train()
354
# modelB.train()
355
#
356
# When saving a model comprised of multiple ``torch.nn.Modules``, such as
357
# a GAN, a sequence-to-sequence model, or an ensemble of models, you
358
# follow the same approach as when you are saving a general checkpoint. In
359
# other words, save a dictionary of each model’s *state_dict* and
360
# corresponding optimizer. As mentioned before, you can save any other
361
# items that may aid you in resuming training by simply appending them to
362
# the dictionary.
363
#
364
# A common PyTorch convention is to save these checkpoints using the
365
# ``.tar`` file extension.
366
#
367
# To load the models, first initialize the models and optimizers, then
368
# load the dictionary locally using ``torch.load()``. From here, you can
369
# easily access the saved items by simply querying the dictionary as you
370
# would expect.
371
#
372
# Remember that you must call ``model.eval()`` to set dropout and batch
373
# normalization layers to evaluation mode before running inference.
374
# Failing to do this will yield inconsistent inference results. If you
375
# wish to resuming training, call ``model.train()`` to set these layers to
376
# training mode.
377
#


######################################################################
# Warmstarting Model Using Parameters from a Different Model
# ----------------------------------------------------------
#
# Save:
# ^^^^^
#
# .. code:: python
#
#     torch.save(modelA.state_dict(), PATH)
#
# Load:
# ^^^^^
#
# .. code:: python
#
#     modelB = TheModelBClass(*args, **kwargs)
#     modelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)
#
# Partially loading a model, or loading a partial model, is a common
# scenario when doing transfer learning or training a new complex model.
# Leveraging trained parameters, even if only a few are usable, will help
# to warmstart the training process and hopefully help your model converge
# much faster than training from scratch.
#
# Whether you are loading from a partial *state_dict*, which is missing
# some keys, or loading a *state_dict* with more keys than the model that
# you are loading into, you can set the ``strict`` argument to **False**
# in the ``load_state_dict()`` function to ignore non-matching keys.
#
# If you want to load parameters from one layer to another, but some keys
# do not match, simply change the names of the parameter keys in the
# *state_dict* that you are loading to match the keys in the model that
# you are loading into.
#
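# For example, a minimal sketch of renaming keys before loading (the
# ``features.`` prefix here is illustrative):
#
# .. code:: python
#
#     state_dict = torch.load(PATH, weights_only=True)
#     # Suppose the saved keys are prefixed with "features." but the
#     # target model's keys are not:
#     renamed = {key.removeprefix('features.'): value
#                for key, value in state_dict.items()}
#     modelB.load_state_dict(renamed, strict=False)
#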


######################################################################
# Saving & Loading Model Across Devices
# -------------------------------------
#
# Save on GPU, Load on CPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#     torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#     device = torch.device('cpu')
#     model = TheModelClass(*args, **kwargs)
#     model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
#
# When loading a model on a CPU that was trained with a GPU, pass
# ``torch.device('cpu')`` to the ``map_location`` argument in the
# ``torch.load()`` function. In this case, the storages underlying the
# tensors are dynamically remapped to the CPU device using the
# ``map_location`` argument.
#
# Save on GPU, Load on GPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#     torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#     device = torch.device("cuda")
#     model = TheModelClass(*args, **kwargs)
#     model.load_state_dict(torch.load(PATH, weights_only=True))
#     model.to(device)
#     # Make sure to call input = input.to(device) on any input tensors that you feed to the model
#
# When loading a model on a GPU that was trained and saved on GPU, simply
# convert the initialized ``model`` to a CUDA optimized model using
# ``model.to(torch.device('cuda'))``. Also, be sure to use the
# ``.to(torch.device('cuda'))`` function on all model inputs to prepare
# the data for the model. Note that calling ``my_tensor.to(device)``
# returns a new copy of ``my_tensor`` on GPU. It does NOT overwrite
# ``my_tensor``. Therefore, remember to manually overwrite tensors:
# ``my_tensor = my_tensor.to(torch.device('cuda'))``.
#
# Save on CPU, Load on GPU
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#     torch.save(model.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#     device = torch.device("cuda")
#     model = TheModelClass(*args, **kwargs)
#     # Choose whatever GPU device number you want
#     model.load_state_dict(torch.load(PATH, weights_only=True, map_location="cuda:0"))
#     model.to(device)
#     # Make sure to call input = input.to(device) on any input tensors that you feed to the model
#
# When loading a model on a GPU that was trained and saved on CPU, set the
# ``map_location`` argument in the ``torch.load()`` function to
# ``cuda:device_id``. This loads the model to a given GPU device. Next, be
# sure to call ``model.to(torch.device('cuda'))`` to convert the model’s
# parameter tensors to CUDA tensors. Finally, be sure to use the
# ``.to(torch.device('cuda'))`` function on all model inputs to prepare
# the data for the CUDA optimized model. Note that calling
# ``my_tensor.to(device)`` returns a new copy of ``my_tensor`` on GPU. It
# does NOT overwrite ``my_tensor``. Therefore, remember to manually
# overwrite tensors: ``my_tensor = my_tensor.to(torch.device('cuda'))``.
#
# Saving ``torch.nn.DataParallel`` Models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# **Save:**
#
# .. code:: python
#
#     torch.save(model.module.state_dict(), PATH)
#
# **Load:**
#
# .. code:: python
#
#     # Load to whatever device you want
#
# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU
# utilization. To save a ``DataParallel`` model generically, save the
# ``model.module.state_dict()``. This way, you have the flexibility to
# load the model any way you want to any device you want.
#
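# For example, a minimal sketch of this round trip (the CPU target device
# is illustrative):
#
# .. code:: python
#
#     # During training, the model is wrapped:
#     model = torch.nn.DataParallel(TheModelClass(*args, **kwargs))
#     # ... train ...
#     torch.save(model.module.state_dict(), PATH)  # save the unwrapped module
#
#     # Later, load into a plain (unwrapped) model on any device:
#     device = torch.device('cpu')  # or torch.device('cuda')
#     loaded = TheModelClass(*args, **kwargs)
#     loaded.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
#     loaded.to(device)
#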