"""
2
Title: Save, serialize, and export models
3
Authors: Neel Kovelamudi, Francois Chollet
4
Date created: 2023/06/14
5
Last modified: 2023/06/30
6
Description: Complete guide to saving, serializing, and exporting models.
7
Accelerator: None
8
"""
9
10
"""
11
## Introduction
12
13
A Keras model consists of multiple components:
14
15
- The architecture, or configuration, which specifies what layers the model
16
contain, and how they're connected.
17
- A set of weights values (the "state of the model").
18
- An optimizer (defined by compiling the model).
19
- A set of losses and metrics (defined by compiling the model).
20
21
The Keras API saves all of these pieces together in a unified format,
22
marked by the `.keras` extension. This is a zip archive consisting of the
23
following:
24
25
- A JSON-based configuration file (config.json): Records of model, layer, and
26
other trackables' configuration.
27
- A H5-based state file, such as `model.weights.h5` (for the whole model),
28
with directory keys for layers and their weights.
29
- A metadata file in JSON, storing things such as the current Keras version.
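
Since the `.keras` file is a plain zip archive, you can peek inside it with Python's
standard-library `zipfile` module. The snippet below is only a minimal sketch: it assumes
a model has already been saved to `my_model.keras` (as done later in this guide), and the
exact entry names may vary slightly across Keras versions.

```python
import zipfile

with zipfile.ZipFile("my_model.keras") as archive:
    # List the archive entries, e.g. config.json, model.weights.h5, metadata.json
    print(archive.namelist())
```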

Let's take a look at how this works.
"""

"""
## How to save and load a model

If you only have 10 seconds to read this guide, here's what you need to know.

**Saving a Keras model:**

```python
model = ...  # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location.keras')  # The file needs to end with the .keras extension
```

**Loading the model back:**

```python
model = keras.models.load_model('path/to/location.keras')
```

Now, let's look at the details.
"""

"""
## Setup
"""

import numpy as np
import keras
from keras import ops

"""
## Saving

This section is about saving an entire model to a single file. The file will include:

- The model's architecture/config
- The model's weight values (which were learned during training)
- The model's compilation information (if `compile()` was called)
- The optimizer and its state, if any (this enables you to restart training
where you left off)

#### APIs

You can save a model with `model.save()` or `keras.models.save_model()` (which is equivalent).
You can load it back with `keras.models.load_model()`.

The only supported format in Keras 3 is the "Keras v3" format,
which uses the `.keras` extension.

**Example:**
"""


def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.keras')` creates a zip archive `my_model.keras`.
model.save("my_model.keras")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model.keras")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

"""
### Custom objects

This section covers the basic workflows for handling custom layers, functions, and
models in Keras saving and reloading.

When saving a model that includes custom objects, such as a subclassed Layer,
you **must** define a `get_config()` method on the object class.
If the arguments passed to the constructor (`__init__()` method) of the custom object
aren't Python objects (anything other than base types like ints, strings,
etc.), then you **must** also explicitly deserialize these arguments in the `from_config()`
class method.

Like this:

```python
class CustomLayer(keras.layers.Layer):
    def __init__(self, sublayer, **kwargs):
        super().__init__(**kwargs)
        self.sublayer = sublayer

    def call(self, x):
        return self.sublayer(x)

    def get_config(self):
        base_config = super().get_config()
        config = {
            "sublayer": keras.saving.serialize_keras_object(self.sublayer),
        }
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        sublayer_config = config.pop("sublayer")
        sublayer = keras.saving.deserialize_keras_object(sublayer_config)
        return cls(sublayer, **config)
```

Please see the [Defining the config methods section](#config_methods) for more
details and examples.

The saved `.keras` file is lightweight and does not store the Python code for custom
objects. Therefore, to reload the model, `load_model` requires access to the definition
of any custom objects used through one of the following methods:

1. Registering custom objects **(preferred)**,
2. Passing custom objects directly when loading, or
3. Using a custom object scope

Below are examples of each workflow:

#### Registering custom objects (**preferred**)

This is the preferred method, as custom object registration greatly simplifies saving and
loading code. Adding the `@keras.saving.register_keras_serializable` decorator to the
class definition of a custom object registers the object globally in a master list,
allowing Keras to recognize the object when loading the model.

Let's create a custom model involving both a custom layer and a custom activation
function to demonstrate this.

**Example:**
"""

# Clear all previously registered custom objects
keras.saving.get_custom_objects().clear()


# Upon registration, you can optionally specify a package or a name.
# If left blank, the package defaults to `Custom` and the name defaults to
# the class name.
@keras.saving.register_keras_serializable(package="MyLayers")
class CustomLayer(keras.layers.Layer):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        return {"factor": self.factor}


@keras.saving.register_keras_serializable(package="my_package", name="custom_fn")
def custom_fn(x):
    return x**2


# Create the model.
def get_model():
    inputs = keras.Input(shape=(4,))
    mid = CustomLayer(0.5)(inputs)
    outputs = keras.layers.Dense(1, activation=custom_fn)(mid)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="rmsprop", loss="mean_squared_error")
    return model


# Train the model.
def train_model(model):
    input = np.random.random((4, 4))
    target = np.random.random((4, 1))
    model.fit(input, target)
    return model


test_input = np.random.random((4, 4))
test_target = np.random.random((4, 1))

model = get_model()
model = train_model(model)
model.save("custom_model.keras")

# Now, we can simply load without worrying about our custom objects.
reconstructed_model = keras.models.load_model("custom_model.keras")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

"""
#### Passing custom objects to `load_model()`
"""

model = get_model()
model = train_model(model)

# Calling `save('custom_model.keras')` creates a zip archive `custom_model.keras`.
model.save("custom_model.keras")

# Upon loading, pass a dict containing the custom objects used in the
# `custom_objects` argument of `keras.models.load_model()`.
reconstructed_model = keras.models.load_model(
    "custom_model.keras",
    custom_objects={"CustomLayer": CustomLayer, "custom_fn": custom_fn},
)

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)


"""
#### Using a custom object scope

Any code within the custom object scope will be able to recognize the custom objects
passed to the scope argument. Therefore, loading the model within the scope will allow
the loading of our custom objects.

**Example:**
"""

model = get_model()
model = train_model(model)
model.save("custom_model.keras")

# Pass the custom objects dictionary to a custom object scope and place
# the `keras.models.load_model()` call within the scope.
custom_objects = {"CustomLayer": CustomLayer, "custom_fn": custom_fn}

with keras.saving.custom_object_scope(custom_objects):
    reconstructed_model = keras.models.load_model("custom_model.keras")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

"""
### Model serialization

This section is about saving only the model's configuration, without its state.
The model's configuration (or architecture) specifies what layers the model
contains, and how these layers are connected. If you have the configuration of a model,
then the model can be created with a freshly initialized state (no weights or compilation
information).

#### APIs

The following serialization APIs are available:

- `keras.models.clone_model(model)`: make a (randomly initialized) copy of a model.
- `get_config()` and `cls.from_config()`: retrieve the configuration of a layer or model, and recreate
a model instance from its config, respectively.
- `model.to_json()` and `keras.models.model_from_json()`: similar, but as JSON strings.
- `keras.saving.serialize_keras_object()`: retrieve the configuration of any arbitrary Keras object.
- `keras.saving.deserialize_keras_object()`: recreate an object instance from its configuration.

#### In-memory model cloning

You can do in-memory cloning of a model via `keras.models.clone_model()`.
This is equivalent to getting the config then recreating the model from its config
(so it does not preserve compilation information or layer weights values).

**Example:**
"""

new_model = keras.models.clone_model(model)
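
"""
Note that `clone_model()` does not copy weight values. If you also want the clone to start
from the same weights as the original, you can transfer them explicitly. This is a small
optional sketch using the `get_weights()`/`set_weights()` APIs covered later in this guide:
"""

new_model.set_weights(model.get_weights())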

"""
#### `get_config()` and `from_config()`

Calling `model.get_config()` or `layer.get_config()` will return a Python dict containing
the configuration of the model or layer, respectively. You should define `get_config()`
to contain arguments needed for the `__init__()` method of the model or layer. At loading time,
the `from_config(config)` method will then call `__init__()` with these arguments to
reconstruct the model or layer.


**Layer example:**
"""

layer = keras.layers.Dense(3, activation="relu")
layer_config = layer.get_config()
print(layer_config)

"""
Now let's reconstruct the layer using the `from_config()` method:
"""

new_layer = keras.layers.Dense.from_config(layer_config)

"""
**Sequential model example:**
"""

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
config = model.get_config()
new_model = keras.Sequential.from_config(config)

"""
**Functional model example:**
"""

inputs = keras.Input((32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)

"""
#### `to_json()` and `keras.models.model_from_json()`

This is similar to `get_config` / `from_config`, except it turns the model
into a JSON string, which can then be loaded without the original model class.
It is also specific to models: it isn't meant for layers.

**Example:**
"""

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)


"""
#### Arbitrary object serialization and deserialization

The `keras.saving.serialize_keras_object()` and `keras.saving.deserialize_keras_object()`
APIs are general-purpose APIs that can be used to serialize or deserialize any Keras
object and any custom object. They are the foundation of saving model architecture and are
behind all `serialize()`/`deserialize()` calls in Keras.

**Example**:
"""

my_reg = keras.regularizers.L1(0.005)
config = keras.saving.serialize_keras_object(my_reg)
print(config)

"""
Note the serialization format containing all the necessary information for proper
reconstruction:

- `module` containing the name of the Keras module or other identifying module the object
comes from
- `class_name` containing the name of the object's class.
- `config` with all the information needed to reconstruct the object
- `registered_name` for custom objects. See [here](#custom_object_serialization).
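
For the `L1` regularizer serialized above, the printed config should look roughly like the
following (exact formatting and values may vary across Keras versions):

```python
{
    "module": "keras.regularizers",
    "class_name": "L1",
    "config": {"l1": 0.005},
    "registered_name": None,
}
```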

Now we can reconstruct the regularizer.
"""

new_reg = keras.saving.deserialize_keras_object(config)

"""
### Model weights saving

You can choose to only save & load a model's weights. This can be useful if:

- You only need the model for inference: in this case you won't need to
restart training, so you don't need the compilation information or optimizer state.
- You are doing transfer learning: in this case you will be training a new model
reusing the state of a prior model, so you don't need the compilation
information of the prior model.

#### APIs for in-memory weight transfer

Weights can be copied between different objects by using `get_weights()`
and `set_weights()`:

* `keras.layers.Layer.get_weights()`: Returns a list of NumPy arrays of weight values.
* `keras.layers.Layer.set_weights(weights)`: Sets the model weights to the values
provided (as NumPy arrays).

Examples:

***Transferring weights from one layer to another, in memory***
"""


def create_layer():
    layer = keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))
    return layer


layer_1 = create_layer()
layer_2 = create_layer()

# Copy weights from layer 1 to layer 2
layer_2.set_weights(layer_1.get_weights())

"""
***Transferring weights from one model to another model with a compatible architecture, in memory***
"""

# Create a simple functional model
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


# Define a subclassed model with the same architecture
class SubclassedModel(keras.Model):
    def __init__(self, output_dim, name=None):
        super().__init__(name=name)
        self.output_dim = output_dim
        self.dense_1 = keras.layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = keras.layers.Dense(64, activation="relu", name="dense_2")
        self.dense_3 = keras.layers.Dense(output_dim, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {"output_dim": self.output_dim, "name": self.name}


subclassed_model = SubclassedModel(10)
# Call the subclassed model once to create the weights.
subclassed_model(np.ones((1, 784)))

# Copy weights from functional_model to subclassed_model.
subclassed_model.set_weights(functional_model.get_weights())

assert len(functional_model.weights) == len(subclassed_model.weights)
for a, b in zip(functional_model.weights, subclassed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

"""
***The case of stateless layers***

Because stateless layers do not change the order or number of weights,
models can have compatible architectures even if there are extra/missing
stateless layers.
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)

# Add a dropout layer, which does not contain any weights.
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp"
)

functional_model_with_dropout.set_weights(functional_model.get_weights())

"""
#### APIs for saving weights to disk & loading them back

Weights can be saved to disk by calling `model.save_weights(filepath)`.
The filename should end in `.weights.h5`.

**Example:**
"""

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("my_model.weights.h5")
sequential_model.load_weights("my_model.weights.h5")

"""
Note that changing `layer.trainable` may result in a different
`layer.weights` ordering when the model contains nested layers.
"""


class NestedDenseLayer(keras.layers.Layer):
    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.dense_1 = keras.layers.Dense(units, name="dense_1")
        self.dense_2 = keras.layers.Dense(units, name="dense_2")

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, "nested")])
variable_names = [v.name for v in nested_model.weights]
print("variables: {}".format(variable_names))

print("\nChanging trainable status of one of the nested layers...")
nested_model.get_layer("nested").dense_1.trainable = False

variable_names_2 = [v.name for v in nested_model.weights]
print("\nvariables: {}".format(variable_names_2))
print("variable ordering changed:", variable_names != variable_names_2)

"""
##### **Transfer learning example**

When loading pretrained weights from a weights file, it is recommended to load
the weights into the original checkpointed model, and then extract
the desired weights/layers into a new model.

**Example:**
"""


def create_functional_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


functional_model = create_functional_model()
functional_model.save_weights("pretrained.weights.h5")

# In a separate program:
pretrained_model = create_functional_model()
pretrained_model.load_weights("pretrained.weights.h5")

# Create a new model by extracting layers from the original model:
extracted_layers = pretrained_model.layers[:-1]
extracted_layers.append(keras.layers.Dense(5, name="dense_3"))
model = keras.Sequential(extracted_layers)
model.summary()
"""
589
### Appendix: Handling custom objects
590
591
<a name="config_methods"></a>
592
#### Defining the config methods
593
594
Specifications:
595
596
* `get_config()` should return a JSON-serializable dictionary in order to be
597
compatible with the Keras architecture- and model-saving APIs.
598
* `from_config(config)` (a `classmethod`) should return a new layer or model
599
object that is created from the config.
600
The default implementation returns `cls(**config)`.
601
602
**NOTE**: If all your constructor arguments are already serializable, e.g. strings and
603
ints, or non-custom Keras objects, overriding `from_config` is not necessary. However,
604
for more complex objects such as layers or models passed to `__init__`, deserialization
605
must be handled explicitly either in `__init__` itself or overriding the `from_config()`
606
method.
607
608
**Example:**
609
"""


@keras.saving.register_keras_serializable(package="MyLayers", name="KernelMult")
class MyDense(keras.layers.Layer):
    def __init__(
        self,
        units,
        *,
        kernel_regularizer=None,
        kernel_initializer=None,
        nested_model=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_units = units
        self.kernel_regularizer = kernel_regularizer
        self.kernel_initializer = kernel_initializer
        self.nested_model = nested_model

    def get_config(self):
        config = super().get_config()
        # Update the config with the custom layer's parameters
        config.update(
            {
                "units": self.hidden_units,
                "kernel_regularizer": self.kernel_regularizer,
                "kernel_initializer": self.kernel_initializer,
                "nested_model": self.nested_model,
            }
        )
        return config

    def build(self, input_shape):
        input_units = input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_units, self.hidden_units),
            regularizer=self.kernel_regularizer,
            initializer=self.kernel_initializer,
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.kernel)


layer = MyDense(units=16, kernel_regularizer="l1", kernel_initializer="ones")
layer3 = MyDense(units=64, nested_model=layer)

config = keras.layers.serialize(layer3)

print(config)

new_layer = keras.layers.deserialize(config)

print(new_layer)

"""
Note that overriding `from_config` is unnecessary above for `MyDense` because
`hidden_units`, `kernel_initializer`, and `kernel_regularizer` are ints, strings, and a
built-in Keras object, respectively. This means that the default `from_config`
implementation of `cls(**config)` will work as intended.

For more complex objects, such as layers and models passed to `__init__`, you
must explicitly deserialize these objects. Let's take a look at an example
of a model where a `from_config` override is necessary.

**Example:**
<a name="registration_example"></a>
"""


@keras.saving.register_keras_serializable(package="ComplexModels")
class CustomModel(keras.layers.Layer):
    def __init__(self, first_layer, second_layer=None, **kwargs):
        super().__init__(**kwargs)
        self.first_layer = first_layer
        if second_layer is not None:
            self.second_layer = second_layer
        else:
            self.second_layer = keras.layers.Dense(8)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "first_layer": self.first_layer,
                "second_layer": self.second_layer,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Note that you can also use `keras.saving.deserialize_keras_object` here
        config["first_layer"] = keras.layers.deserialize(config["first_layer"])
        config["second_layer"] = keras.layers.deserialize(config["second_layer"])
        return cls(**config)

    def call(self, inputs):
        return self.first_layer(self.second_layer(inputs))


# Let's make our first layer the custom layer from the previous example (MyDense)
inputs = keras.Input((32,))
outputs = CustomModel(first_layer=layer)(inputs)
model = keras.Model(inputs, outputs)

config = model.get_config()
new_model = keras.Model.from_config(config)

"""
<a name="custom_object_serialization"></a>
#### How custom objects are serialized

The serialization format has a special key for custom objects registered via
`@keras.saving.register_keras_serializable`. This `registered_name` key allows for easy
retrieval at loading/deserialization time while also allowing users to add custom naming.

Let's take a look at the config from serializing the custom layer `MyDense` we defined
above.

**Example**:
"""

layer = MyDense(
    units=16,
    kernel_regularizer=keras.regularizers.L1L2(l1=1e-5, l2=1e-4),
    kernel_initializer="ones",
)
config = keras.layers.serialize(layer)
print(config)

"""
As shown, the `registered_name` key contains the lookup information for the Keras master
list, including the package `MyLayers` and the custom name `KernelMult` that we gave in
the `@keras.saving.register_keras_serializable` decorator. Take a look again at the custom
class definition/registration [here](#registration_example).

Note that the `class_name` key contains the original name of the class, allowing for
proper re-initialization in `from_config`.

Additionally, note that the `module` key is `None` since this is a custom object.
"""