GitHub Repository: keras-team/keras-io
Path: blob/master/guides/customizing_saving_and_serialization.py
"""
Title: Customizing Saving and Serialization
Author: Neel Kovelamudi
Date created: 2023/03/15
Last modified: 2023/03/15
Description: A more advanced guide on customizing saving for your layers and models.
Accelerator: None
"""

"""
## Introduction

This guide covers advanced methods that can be customized in Keras saving. For most
users, the methods outlined in the primary
[Serialize, save, and export guide](https://keras.io/guides/serialization_and_saving)
are sufficient.
"""
"""
20
### APIs
21
We will cover the following APIs:
22
23
- `save_assets()` and `load_assets()`
24
- `save_own_variables()` and `load_own_variables()`
25
- `get_build_config()` and `build_from_config()`
26
- `get_compile_config()` and `compile_from_config()`
27
28
When restoring a model, these get executed in the following order:
29
30
- `build_from_config()`
31
- `compile_from_config()`
32
- `load_own_variables()`
33
- `load_assets()`
34
35
"""

"""
## Setup
"""

import os
import numpy as np
import keras

"""
## State saving customization

These methods determine how the state of your model's layers is saved when calling
`model.save()`. You can override them to take full control of the state saving process.
"""

"""
### `save_own_variables()` and `load_own_variables()`

These methods save and load the state variables of the layer when `model.save()` and
`keras.models.load_model()` are called, respectively. By default, the state variables
saved and loaded are the weights of the layer (both trainable and non-trainable). Here is
the default implementation of `save_own_variables()`:

```python
def save_own_variables(self, store):
    all_vars = self._trainable_weights + self._non_trainable_weights
    for i, v in enumerate(all_vars):
        store[f"{i}"] = v.numpy()
```

The store used by these methods is a dictionary that can be populated with the layer
variables. Let's take a look at an example customizing this.
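The default `load_own_variables()` mirrors this, reading the same integer keys back
into the layer's weights. Roughly (a sketch, not the exact library source):

```python
def load_own_variables(self, store):
    all_vars = self._trainable_weights + self._non_trainable_weights
    for i, v in enumerate(all_vars):
        v.assign(store[f"{i}"])
```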
**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomVariable(keras.layers.Dense):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.my_variable = keras.Variable(
            np.random.random((units,)), name="my_variable", dtype="float32"
        )

    def save_own_variables(self, store):
        super().save_own_variables(store)
        # Stores the value of the variable upon saving
        store["variables"] = self.my_variable.numpy()

    def load_own_variables(self, store):
        # Assigns the value of the variable upon loading
        self.my_variable.assign(store["variables"])
        # Load the remaining weights
        for i, v in enumerate(self.weights):
            v.assign(store[f"{i}"])
        # Note: You must specify how all variables (including layer weights)
        # are loaded in `load_own_variables()`.

    def call(self, inputs):
        dense_out = super().call(inputs)
        return dense_out + self.my_variable


model = keras.Sequential([LayerWithCustomVariable(1)])

ref_input = np.random.random((8, 10))
ref_output = np.random.random((8, 10))
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(ref_input, ref_output)

model.save("custom_vars_model.keras")
restored_model = keras.models.load_model("custom_vars_model.keras")

np.testing.assert_allclose(
    model.layers[0].my_variable.numpy(),
    restored_model.layers[0].my_variable.numpy(),
)

"""
117
### `save_assets()` and `load_assets()`
118
119
These methods can be added to your model class definition to store and load any
120
additional information that your model needs.
121
122
For example, NLP domain layers such as TextVectorization layers and IndexLookup layers
123
may need to store their associated vocabulary (or lookup table) in a text file upon
124
saving.
125
126
Let's take at the basics of this workflow with a simple file `assets.txt`.
127
128
**Example:**
129
"""
130
131
132
@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomAssets(keras.layers.Dense):
    def __init__(self, vocab=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab = vocab

    def save_assets(self, inner_path):
        # Writes the vocab (sentence) to text file at save time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "w") as f:
            f.write(self.vocab)

    def load_assets(self, inner_path):
        # Reads the vocab (sentence) from text file at load time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "r") as f:
            text = f.read()
        self.vocab = text.replace("<unk>", "little")


model = keras.Sequential(
    [LayerWithCustomAssets(vocab="Mary had a <unk> lamb.", units=5)]
)

x = np.random.random((10, 10))
y = model(x)

model.save("custom_assets_model.keras")
restored_model = keras.models.load_model("custom_assets_model.keras")

np.testing.assert_string_equal(
    restored_model.layers[0].vocab, "Mary had a little lamb."
)

"""
165
## `build` and `compile` saving customization
166
167
### `get_build_config()` and `build_from_config()`
168
169
These methods work together to save the layer's built states and restore them upon
170
loading.
171
172
By default, this only includes a build config dictionary with the layer's input shape,
173
but overriding these methods can be used to include further Variables and Lookup Tables
174
that can be useful to restore for your built model.
175
176
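Conceptually, the defaults behave like the following sketch (the attribute name
`_stored_input_shape` is illustrative, not an actual Keras internal):

```python
def get_build_config(self):
    # Record the shape the layer was built with (attribute name illustrative).
    return {"input_shape": self._stored_input_shape}


def build_from_config(self, config):
    # Re-run `build()` with the recorded shape at load time.
    if config.get("input_shape") is not None:
        self.build(config["input_shape"])
```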

**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomBuild(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return dict(units=self.units, **super().get_config())

    def build(self, input_shape, layer_init):
        # Note the overriding of `build()` to add an extra argument.
        # Therefore, we will need to manually call build with the `layer_init`
        # argument before the first execution of `call()`.
        super().build(input_shape)
        self._input_shape = input_shape
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=layer_init,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=layer_init,
            trainable=True,
        )
        self.layer_init = layer_init

    def get_build_config(self):
        build_config = {
            "layer_init": self.layer_init,
            "input_shape": self._input_shape,
        }  # Stores our initializer for `build()`
        return build_config

    def build_from_config(self, config):
        # Calls `build()` with the parameters at loading time
        self.build(config["input_shape"], config["layer_init"])


custom_layer = LayerWithCustomBuild(units=16)
custom_layer.build(input_shape=(8,), layer_init="random_normal")

model = keras.Sequential(
    [
        custom_layer,
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

x = np.random.random((16, 8))
y = model(x)

model.save("custom_build_model.keras")
restored_model = keras.models.load_model("custom_build_model.keras")

np.testing.assert_equal(restored_model.layers[0].layer_init, "random_normal")
np.testing.assert_equal(restored_model.built, True)

"""
242
### `get_compile_config()` and `compile_from_config()`
243
244
These methods work together to save the information with which the model was compiled
245
(optimizers, losses, etc.) and restore and re-compile the model with this information.
246
247
Overriding these methods can be useful for compiling the restored model with custom
248
optimizers, custom losses, etc., as these will need to be deserialized prior to calling
249
`model.compile` in `compile_from_config()`.
250
251
Let's take a look at an example of this.
252
253

**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
def small_square_sum_loss(y_true, y_pred):
    loss = keras.ops.square(y_pred - y_true)
    loss = loss / 10.0
    loss = keras.ops.sum(loss, axis=1)
    return loss


@keras.saving.register_keras_serializable(package="my_custom_package")
def mean_pred(y_true, y_pred):
    return keras.ops.mean(y_pred)


@keras.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomCompile(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(4, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def compile(self, optimizer, loss_fn, metrics):
        super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
        self.model_optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_metrics = metrics

    def get_compile_config(self):
        # These parameters will be serialized at saving time.
        return {
            "model_optimizer": self.model_optimizer,
            "loss_fn": self.loss_fn,
            "metric": self.loss_metrics,
        }

    def compile_from_config(self, config):
        # Deserializes the compile parameters (important, since many are custom)
        optimizer = keras.saving.deserialize_keras_object(config["model_optimizer"])
        loss_fn = keras.saving.deserialize_keras_object(config["loss_fn"])
        metrics = keras.saving.deserialize_keras_object(config["metric"])

        # Calls compile with the deserialized parameters
        self.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metrics)


model = ModelWithCustomCompile()
model.compile(
    optimizer="SGD", loss_fn=small_square_sum_loss, metrics=["accuracy", mean_pred]
)

x = np.random.random((4, 8))
y = np.random.random((4,))

model.fit(x, y)

model.save("custom_compile_model.keras")
restored_model = keras.models.load_model("custom_compile_model.keras")

np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer)
np.testing.assert_equal(model.loss_fn, restored_model.loss_fn)
np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics)

"""
## Conclusion

The methods covered in this guide enable a wide variety of use cases, allowing the
saving and loading of complex models with exotic assets and state elements. To recap:

- `save_own_variables` and `load_own_variables` determine how your layer's state
variables are saved and loaded.
- `save_assets` and `load_assets` can be added to store and load any additional
information your model needs.
- `get_build_config` and `build_from_config` save and restore the model's built
states.
- `get_compile_config` and `compile_from_config` save and restore the model's
compiled states.
"""