"""1Title: Save, serialize, and export models2Authors: Neel Kovelamudi, Francois Chollet3Date created: 2023/06/144Last modified: 2023/06/305Description: Complete guide to saving, serializing, and exporting models.6Accelerator: None7"""89"""10## Introduction1112A Keras model consists of multiple components:1314- The architecture, or configuration, which specifies what layers the model15contain, and how they're connected.16- A set of weights values (the "state of the model").17- An optimizer (defined by compiling the model).18- A set of losses and metrics (defined by compiling the model).1920The Keras API saves all of these pieces together in a unified format,21marked by the `.keras` extension. This is a zip archive consisting of the22following:2324- A JSON-based configuration file (config.json): Records of model, layer, and25other trackables' configuration.26- A H5-based state file, such as `model.weights.h5` (for the whole model),27with directory keys for layers and their weights.28- A metadata file in JSON, storing things such as the current Keras version.2930Let's take a look at how this works.31"""3233"""34## How to save and load a model3536If you only have 10 seconds to read this guide, here's what you need to know.3738**Saving a Keras model:**3940```python41model = ... # Get model (Sequential, Functional Model, or Model subclass)42model.save('path/to/location.keras') # The file needs to end with the .keras extension43```4445**Loading the model back:**4647```python48model = keras.models.load_model('path/to/location.keras')49```5051Now, let's look at the details.52"""5354"""55## Setup56"""5758import numpy as np59import keras60from keras import ops6162"""63## Saving6465This section is about saving an entire model to a single file. 
The file will include:

- The model's architecture/config
- The model's weight values (which were learned during training)
- The model's compilation information (if `compile()` was called)
- The optimizer and its state, if any (this enables you to restart training
where you left off)

#### APIs

You can save a model with `model.save()` or `keras.models.save_model()` (which is equivalent).
You can load it back with `keras.models.load_model()`.

The only supported format in Keras 3 is the "Keras v3" format,
which uses the `.keras` extension.

**Example:**
"""


def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# Calling `save('my_model.keras')` creates a zip archive `my_model.keras`.
model.save("my_model.keras")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model.keras")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

"""
### Custom objects

This section covers the basic workflows for handling custom layers, functions, and
models in Keras saving and reloading.

When saving a model that includes custom objects, such as a subclassed Layer,
you **must** define a `get_config()` method on the object class.
If the arguments passed to the constructor (`__init__()` method) of the custom object
aren't basic Python types (that is, they are anything other than base types like
ints, strings, etc.), then you **must** also explicitly deserialize these arguments
in the `from_config()` class method.

Like this:

```python
class CustomLayer(keras.layers.Layer):
    def __init__(self, sublayer, **kwargs):
        super().__init__(**kwargs)
        self.sublayer = sublayer

    def call(self, x):
        return self.sublayer(x)

    def get_config(self):
        base_config = super().get_config()
        config = {
            "sublayer": keras.saving.serialize_keras_object(self.sublayer),
        }
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        sublayer_config = config.pop("sublayer")
        sublayer = keras.saving.deserialize_keras_object(sublayer_config)
        return cls(sublayer, **config)
```

Please see the [Defining the config methods section](#config_methods) for more
details and examples.

The saved `.keras` file is lightweight and does not store the Python code for custom
objects. Therefore, to reload the model, `load_model` requires access to the definition
of any custom objects used through one of the following methods:

1. Registering custom objects **(preferred)**,
2. Passing custom objects directly when loading, or
3. Using a custom object scope

Below are examples of each workflow:

#### Registering custom objects (**preferred**)

This is the preferred method, as custom object registration greatly simplifies saving and
loading code.
Adding the `@keras.saving.register_keras_serializable` decorator to the
class definition of a custom object registers the object globally in a master list,
allowing Keras to recognize the object when loading the model.

Let's create a custom model involving both a custom layer and a custom activation
function to demonstrate this.

**Example:**
"""

# Clear all previously registered custom objects
keras.saving.get_custom_objects().clear()


# Upon registration, you can optionally specify a package or a name.
# If left blank, the package defaults to `Custom` and the name defaults to
# the class name.
@keras.saving.register_keras_serializable(package="MyLayers")
class CustomLayer(keras.layers.Layer):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        return {"factor": self.factor}


@keras.saving.register_keras_serializable(package="my_package", name="custom_fn")
def custom_fn(x):
    return x**2


# Create the model.
def get_model():
    inputs = keras.Input(shape=(4,))
    mid = CustomLayer(0.5)(inputs)
    outputs = keras.layers.Dense(1, activation=custom_fn)(mid)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="rmsprop", loss="mean_squared_error")
    return model


# Train the model.
def train_model(model):
    input = np.random.random((4, 4))
    target = np.random.random((4, 1))
    model.fit(input, target)
    return model


test_input = np.random.random((4, 4))
test_target = np.random.random((4, 1))

model = get_model()
model = train_model(model)
model.save("custom_model.keras")

# Now, we can simply load without worrying about our custom objects.
reconstructed_model = keras.models.load_model("custom_model.keras")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

"""
#### Passing custom objects to `load_model()`
"""

model = get_model()
model = train_model(model)

# Calling `save('custom_model.keras')` creates a zip archive `custom_model.keras`.
model.save("custom_model.keras")

# Upon loading, pass a dict containing the custom objects used in the
# `custom_objects` argument of `keras.models.load_model()`.
reconstructed_model = keras.models.load_model(
    "custom_model.keras",
    custom_objects={"CustomLayer": CustomLayer, "custom_fn": custom_fn},
)

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)


"""
#### Using a custom object scope

Any code within the custom object scope will be able to recognize the custom objects
passed to the scope argument. Therefore, loading the model within the scope will allow
the loading of our custom objects.

**Example:**
"""

model = get_model()
model = train_model(model)
model.save("custom_model.keras")

# Pass the custom objects dictionary to a custom object scope and place
# the `keras.models.load_model()` call within the scope.
custom_objects = {"CustomLayer": CustomLayer, "custom_fn": custom_fn}

with keras.saving.custom_object_scope(custom_objects):
    reconstructed_model = keras.models.load_model("custom_model.keras")

# Let's check:
np.testing.assert_allclose(
    model.predict(test_input), reconstructed_model.predict(test_input)
)

"""
### Model serialization

This section is about saving only the model's configuration, without its state.
The model's configuration (or architecture) specifies what layers the model
contains, and how these layers are connected.
If you have the configuration of a model,
then the model can be created with a freshly initialized state (no weights or compilation
information).

#### APIs

The following serialization APIs are available:

- `keras.models.clone_model(model)`: make a (randomly initialized) copy of a model.
- `get_config()` and `cls.from_config()`: retrieve the configuration of a layer or model, and recreate
a model instance from its config, respectively.
- `model.to_json()` and `keras.models.model_from_json()`: similar, but as JSON strings.
- `keras.saving.serialize_keras_object()`: retrieve the configuration of any arbitrary Keras object.
- `keras.saving.deserialize_keras_object()`: recreate an object instance from its configuration.

#### In-memory model cloning

You can do in-memory cloning of a model via `keras.models.clone_model()`.
This is equivalent to getting the config then recreating the model from its config
(so it does not preserve compilation information or layer weight values).

**Example:**
"""

new_model = keras.models.clone_model(model)

"""
#### `get_config()` and `from_config()`

Calling `model.get_config()` or `layer.get_config()` will return a Python dict containing
the configuration of the model or layer, respectively. You should define `get_config()`
to contain arguments needed for the `__init__()` method of the model or layer.
At loading time,
the `from_config(config)` method will then call `__init__()` with these arguments to
reconstruct the model or layer.


**Layer example:**
"""

layer = keras.layers.Dense(3, activation="relu")
layer_config = layer.get_config()
print(layer_config)

"""
Now let's reconstruct the layer using the `from_config()` method:
"""

new_layer = keras.layers.Dense.from_config(layer_config)

"""
**Sequential model example:**
"""

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
config = model.get_config()
new_model = keras.Sequential.from_config(config)

"""
**Functional model example:**
"""

inputs = keras.Input((32,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)

"""
#### `to_json()` and `keras.models.model_from_json()`

This is similar to `get_config` / `from_config`, except it turns the model
into a JSON string, which can then be loaded without the original model class.
It is also specific to models; it isn't meant for layers.

**Example:**
"""

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()
new_model = keras.models.model_from_json(json_config)


"""
#### Arbitrary object serialization and deserialization

The `keras.saving.serialize_keras_object()` and `keras.saving.deserialize_keras_object()`
APIs are general-purpose APIs that can be used to serialize or deserialize any Keras
object and any custom object.
They are at the foundation of saving model architecture and are
behind all `serialize()`/`deserialize()` calls in Keras.

**Example**:
"""

my_reg = keras.regularizers.L1(0.005)
config = keras.saving.serialize_keras_object(my_reg)
print(config)

"""
Note the serialization format containing all the necessary information for proper
reconstruction:

- `module` containing the name of the Keras module or other identifying module the object
comes from
- `class_name` containing the name of the object's class.
- `config` with all the information needed to reconstruct the object
- `registered_name` for custom objects. See [here](#custom_object_serialization).

Now we can reconstruct the regularizer.
"""

new_reg = keras.saving.deserialize_keras_object(config)

"""
### Model weights saving

You can choose to only save & load a model's weights. This can be useful if:

- You only need the model for inference: in this case you won't need to
restart training, so you don't need the compilation information or optimizer state.
- You are doing transfer learning: in this case you will be training a new model
reusing the state of a prior model, so you don't need the compilation
information of the prior model.

#### APIs for in-memory weight transfer

Weights can be copied between different objects by using `get_weights()`
and `set_weights()`:

* `keras.layers.Layer.get_weights()`: Returns a list of NumPy arrays of weight values.
* `keras.layers.Layer.set_weights(weights)`: Sets the layer weights to the values
provided (as NumPy arrays).

Examples:

***Transferring weights from one layer to another, in memory***
"""


def create_layer():
    layer = keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))
    return layer


layer_1 = create_layer()
layer_2 = create_layer()

# Copy weights from layer 1 to layer 2
layer_2.set_weights(layer_1.get_weights())

"""
***Transferring weights from one model to another model with a compatible architecture, in memory***
"""

# Create a simple functional model
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


# Define a subclassed model with the same architecture
class SubclassedModel(keras.Model):
    def __init__(self, output_dim, name=None):
        super().__init__(name=name)
        self.output_dim = output_dim
        self.dense_1 = keras.layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = keras.layers.Dense(64, activation="relu", name="dense_2")
        self.dense_3 = keras.layers.Dense(output_dim, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {"output_dim": self.output_dim, "name": self.name}


subclassed_model = SubclassedModel(10)
# Call the subclassed model once to create the weights.
subclassed_model(np.ones((1, 784)))

# Copy weights from functional_model to subclassed_model.
subclassed_model.set_weights(functional_model.get_weights())

assert len(functional_model.weights) == len(subclassed_model.weights)
for a, b in zip(functional_model.weights, subclassed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

"""
***The case of stateless layers***

Because stateless layers do not change the order or number of weights,
models can have compatible architectures even if there are extra/missing
stateless layers.
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)

# Add a dropout layer, which does not contain any weights.
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp"
)

functional_model_with_dropout.set_weights(functional_model.get_weights())

"""
#### APIs for saving weights to disk & loading them back

Weights can be saved to disk by calling `model.save_weights(filepath)`.
The filename should end in `.weights.h5`.

**Example:**
"""

# Runnable example
sequential_model = keras.Sequential(
    [
        keras.Input(shape=(784,), name="digits"),
        keras.layers.Dense(64, activation="relu", name="dense_1"),
        keras.layers.Dense(64, activation="relu", name="dense_2"),
        keras.layers.Dense(10, name="predictions"),
    ]
)
sequential_model.save_weights("my_model.weights.h5")
sequential_model.load_weights("my_model.weights.h5")

"""
Note that changing `layer.trainable` may result in a different
`layer.weights` ordering when the model contains nested layers.
"""


class NestedDenseLayer(keras.layers.Layer):
    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.dense_1 = keras.layers.Dense(units, name="dense_1")
        self.dense_2 = keras.layers.Dense(units, name="dense_2")

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


nested_model = keras.Sequential([keras.Input((784,)), NestedDenseLayer(10, "nested")])
variable_names = [v.name for v in nested_model.weights]
print("variables: {}".format(variable_names))

print("\nChanging trainable status of one of the nested layers...")
nested_model.get_layer("nested").dense_1.trainable = False

variable_names_2 = [v.name for v in nested_model.weights]
print("\nvariables: {}".format(variable_names_2))
print("variable ordering changed:", variable_names != variable_names_2)

"""
##### **Transfer learning example**

When loading pretrained weights from a weights file, it is recommended to load
the weights into the original checkpointed model, and then extract
the desired weights/layers into a new model.

**Example:**
"""


def create_functional_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")


functional_model = create_functional_model()
functional_model.save_weights("pretrained.weights.h5")

# In a separate program:
pretrained_model = create_functional_model()
pretrained_model.load_weights("pretrained.weights.h5")

# Create a new model by extracting layers from the original model:
extracted_layers = pretrained_model.layers[:-1]
extracted_layers.append(keras.layers.Dense(5, name="dense_3"))
model = keras.Sequential(extracted_layers)
model.summary()

"""
### Appendix: Handling custom objects

<a name="config_methods"></a>
#### Defining the config methods

Specifications:

* `get_config()` should return a JSON-serializable dictionary in order to be
compatible with the Keras architecture- and model-saving APIs.
* `from_config(config)` (a `classmethod`) should return a new layer or model
object that is created from the config.
The default implementation returns `cls(**config)`.

**NOTE**: If all your constructor arguments are already
serializable, e.g. strings and
ints, or non-custom Keras objects, overriding `from_config` is not necessary. However,
for more complex objects such as layers or models passed to `__init__`, deserialization
must be handled explicitly, either in `__init__` itself or by overriding the `from_config()`
method.

**Example:**
"""


@keras.saving.register_keras_serializable(package="MyLayers", name="KernelMult")
class MyDense(keras.layers.Layer):
    def __init__(
        self,
        units,
        *,
        kernel_regularizer=None,
        kernel_initializer=None,
        nested_model=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_units = units
        self.kernel_regularizer = kernel_regularizer
        self.kernel_initializer = kernel_initializer
        self.nested_model = nested_model

    def get_config(self):
        config = super().get_config()
        # Update the config with the custom layer's parameters
        config.update(
            {
                "units": self.hidden_units,
                "kernel_regularizer": self.kernel_regularizer,
                "kernel_initializer": self.kernel_initializer,
                "nested_model": self.nested_model,
            }
        )
        return config

    def build(self, input_shape):
        input_units = input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_units, self.hidden_units),
            regularizer=self.kernel_regularizer,
            initializer=self.kernel_initializer,
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.kernel)


layer = MyDense(units=16, kernel_regularizer="l1", kernel_initializer="ones")
layer3 = MyDense(units=64, nested_model=layer)

config = keras.layers.serialize(layer3)

print(config)

new_layer = keras.layers.deserialize(config)

print(new_layer)

"""
Note that overriding `from_config` is unnecessary above for `MyDense` because
`hidden_units`, `kernel_initializer`, and `kernel_regularizer` are ints, strings, and a
built-in Keras object, respectively.
This means that the default `from_config`
implementation of `cls(**config)` will work as intended.

For more complex objects, such as layers and models passed to `__init__`, for
example, you must explicitly deserialize these objects. Let's take a look at an example
of a model where a `from_config` override is necessary.

**Example:**
<a name="registration_example"></a>
"""


@keras.saving.register_keras_serializable(package="ComplexModels")
class CustomModel(keras.layers.Layer):
    def __init__(self, first_layer, second_layer=None, **kwargs):
        super().__init__(**kwargs)
        self.first_layer = first_layer
        if second_layer is not None:
            self.second_layer = second_layer
        else:
            self.second_layer = keras.layers.Dense(8)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "first_layer": self.first_layer,
                "second_layer": self.second_layer,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Note that you can also use `keras.saving.deserialize_keras_object` here
        config["first_layer"] = keras.layers.deserialize(config["first_layer"])
        config["second_layer"] = keras.layers.deserialize(config["second_layer"])
        return cls(**config)

    def call(self, inputs):
        return self.first_layer(self.second_layer(inputs))


# Let's make our first layer the custom layer from the previous example (MyDense)
inputs = keras.Input((32,))
outputs = CustomModel(first_layer=layer)(inputs)
model = keras.Model(inputs, outputs)

config = model.get_config()
new_model = keras.Model.from_config(config)

"""
<a name="custom_object_serialization"></a>
#### How custom objects are serialized

The serialization format has a special key for custom objects registered via
`@keras.saving.register_keras_serializable`.
This `registered_name` key allows for easy
retrieval at loading/deserialization time while also allowing users to add custom naming.

Let's take a look at the config from serializing the custom layer `MyDense` we defined
above.

**Example**:
"""

layer = MyDense(
    units=16,
    kernel_regularizer=keras.regularizers.L1L2(l1=1e-5, l2=1e-4),
    kernel_initializer="ones",
)
config = keras.layers.serialize(layer)
print(config)

"""
As shown, the `registered_name` key contains the lookup information for the Keras master
list, including the package `MyLayers` and the custom name `KernelMult` that we gave in
the `@keras.saving.register_keras_serializable` decorator. Take a look again at the custom
class definition/registration [here](#registration_example).

Note that the `class_name` key contains the original name of the class, allowing for
proper re-initialization in `from_config`.

Additionally, note that the `module` key is `None` since this is a custom object.
"""
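"""
To make the lookup idea concrete, here is a toy, pure-Python model of a registry keyed
by a `package>name` string. It is only an illustrative sketch: `REGISTRY`, `register`,
`serialize`, `deserialize`, and the `_registered_name` attribute are hypothetical names
invented for this example, and this is not Keras's actual implementation, which
handles many more cases (built-in objects, functions, nested configs, and so on).

```python
# Toy registry mapping "package>name" keys to Python classes.
# Every name below is hypothetical; this only mimics the idea described above.
REGISTRY = {}


def register(package="Custom", name=None):
    # Decorator that stores the class under the key "package>name".
    def decorator(cls):
        key = f"{package}>{name or cls.__name__}"
        REGISTRY[key] = cls
        cls._registered_name = key
        return cls

    return decorator


def serialize(obj):
    # Produce a dict with the same four keys discussed in this section.
    return {
        "module": None,  # custom objects have no Keras module
        "class_name": type(obj).__name__,
        "config": obj.get_config(),
        "registered_name": type(obj)._registered_name,
    }


def deserialize(config):
    # Look the class up by its registered name, then rebuild it from its config.
    cls = REGISTRY[config["registered_name"]]
    return cls(**config["config"])


@register(package="MyLayers", name="KernelMult")
class MyDense:
    def __init__(self, units):
        self.units = units

    def get_config(self):
        return {"units": self.units}


config = serialize(MyDense(16))
restored = deserialize(config)
print(config["registered_name"])  # MyLayers>KernelMult
print(type(restored).__name__, restored.units)  # MyDense 16
```

The point of the sketch is that the saved config never contains code, only a name;
reconstruction works only if that name resolves to a class that is defined (and
registered) in the loading program.
"""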