"""1Title: Customizing Saving and Serialization2Author: Neel Kovelamudi3Date created: 2023/03/154Last modified: 2023/03/155Description: A more advanced guide on customizing saving for your layers and models.6Accelerator: None7"""89"""10## Introduction1112This guide covers advanced methods that can be customized in Keras saving. For most13users, the methods outlined in the primary14[Serialize, save, and export guide](https://keras.io/guides/serialization_and_saving)15are sufficient.16"""1718"""19### APIs20We will cover the following APIs:2122- `save_assets()` and `load_assets()`23- `save_own_variables()` and `load_own_variables()`24- `get_build_config()` and `build_from_config()`25- `get_compile_config()` and `compile_from_config()`2627When restoring a model, these get executed in the following order:2829- `build_from_config()`30- `compile_from_config()`31- `load_own_variables()`32- `load_assets()`3334"""3536"""37## Setup38"""3940import os41import numpy as np42import keras4344"""45## State saving customization4647These methods determine how the state of your model's layers is saved when calling48`model.save()`. You can override them to take full control of the state saving process.49"""5051"""52### `save_own_variables()` and `load_own_variables()`5354These methods save and load the state variables of the layer when `model.save()` and55`keras.models.load_model()` are called, respectively. By default, the state variables56saved and loaded are the weights of the layer (both trainable and non-trainable). Here is57the default implementation of `save_own_variables()`:5859```python60def save_own_variables(self, store):61all_vars = self._trainable_weights + self._non_trainable_weights62for i, v in enumerate(all_vars):63store[f"{i}"] = v.numpy()64```6566The store used by these methods is a dictionary that can be populated with the layer67variables. 
Let's take a look at an example customizing this.

**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomVariable(keras.layers.Dense):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.my_variable = keras.Variable(
            np.random.random((units,)), name="my_variable", dtype="float32"
        )

    def save_own_variables(self, store):
        super().save_own_variables(store)
        # Stores the value of the variable upon saving
        store["variables"] = self.my_variable.numpy()

    def load_own_variables(self, store):
        # Assigns the value of the variable upon loading
        self.my_variable.assign(store["variables"])
        # Load the remaining weights
        for i, v in enumerate(self.weights):
            v.assign(store[f"{i}"])
        # Note: You must specify how all variables (including layer weights)
        # are loaded in `load_own_variables`.

    def call(self, inputs):
        dense_out = super().call(inputs)
        return dense_out + self.my_variable


model = keras.Sequential([LayerWithCustomVariable(1)])

ref_input = np.random.random((8, 10))
# The target shape matches the layer's single output unit.
ref_output = np.random.random((8, 1))
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(ref_input, ref_output)

model.save("custom_vars_model.keras")
restored_model = keras.models.load_model("custom_vars_model.keras")

np.testing.assert_allclose(
    model.layers[0].my_variable.numpy(),
    restored_model.layers[0].my_variable.numpy(),
)

"""
### `save_assets()` and `load_assets()`

These methods can be added to your model class definition to store and load any
additional information that your model needs.

For example, NLP domain layers such as TextVectorization layers and IndexLookup layers
may need to store their associated vocabulary (or lookup table) in a text file upon
saving.

Let's take a look at the basics of this workflow with a simple file
`vocabulary.txt`.

**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomAssets(keras.layers.Dense):
    def __init__(self, vocab=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab = vocab

    def save_assets(self, inner_path):
        # Writes the vocab (sentence) to text file at save time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "w") as f:
            f.write(self.vocab)

    def load_assets(self, inner_path):
        # Reads the vocab (sentence) from text file at load time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "r") as f:
            text = f.read()
        self.vocab = text.replace("<unk>", "little")


model = keras.Sequential(
    [LayerWithCustomAssets(vocab="Mary had a <unk> lamb.", units=5)]
)

x = np.random.random((10, 10))
y = model(x)

model.save("custom_assets_model.keras")
restored_model = keras.models.load_model("custom_assets_model.keras")

np.testing.assert_string_equal(
    restored_model.layers[0].vocab, "Mary had a little lamb."
)

"""
## `build` and `compile` saving customization

### `get_build_config()` and `build_from_config()`

These methods work together to save the layer's built state and restore it upon
loading.

By default, this only includes a build config dictionary with the layer's input shape,
but overriding these methods can be used to include further Variables and Lookup Tables
that can be useful to restore for your built model.

**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomBuild(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return dict(units=self.units, **super().get_config())

    def build(self, input_shape, layer_init):
        # Note the overriding of
        # `build()` to add an extra argument.
        # Therefore, we will need to manually call build with `layer_init` argument
        # before the first execution of `call()`.
        super().build(input_shape)
        self._input_shape = input_shape
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=layer_init,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=layer_init,
            trainable=True,
        )
        self.layer_init = layer_init

    def get_build_config(self):
        build_config = {
            "layer_init": self.layer_init,
            "input_shape": self._input_shape,
        }  # Stores our initializer for `build()`
        return build_config

    def build_from_config(self, config):
        # Calls `build()` with the parameters at loading time
        self.build(config["input_shape"], config["layer_init"])


custom_layer = LayerWithCustomBuild(units=16)
custom_layer.build(input_shape=(8,), layer_init="random_normal")

model = keras.Sequential(
    [
        custom_layer,
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

x = np.random.random((16, 8))
y = model(x)

model.save("custom_build_model.keras")
restored_model = keras.models.load_model("custom_build_model.keras")

np.testing.assert_equal(restored_model.layers[0].layer_init, "random_normal")
np.testing.assert_equal(restored_model.built, True)

"""
### `get_compile_config()` and `compile_from_config()`

These methods work together to save the information with which the model was compiled
(optimizers, losses, etc.)
and restore and re-compile the model with this information.

Overriding these methods can be useful for compiling the restored model with custom
optimizers, custom losses, etc., as these will need to be deserialized prior to calling
`model.compile` in `compile_from_config()`.

Let's take a look at an example of this.

**Example:**
"""


@keras.saving.register_keras_serializable(package="my_custom_package")
def small_square_sum_loss(y_true, y_pred):
    loss = keras.ops.square(y_pred - y_true)
    loss = loss / 10.0
    loss = keras.ops.sum(loss, axis=1)
    return loss


@keras.saving.register_keras_serializable(package="my_custom_package")
def mean_pred(y_true, y_pred):
    return keras.ops.mean(y_pred)


@keras.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomCompile(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(4, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def compile(self, optimizer, loss_fn, metrics):
        super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
        self.model_optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_metrics = metrics

    def get_compile_config(self):
        # These parameters will be serialized at saving time.
        return {
            "model_optimizer": self.model_optimizer,
            "loss_fn": self.loss_fn,
            "metric": self.loss_metrics,
        }

    def compile_from_config(self, config):
        # Deserializes the compile parameters (important, since many are custom)
        optimizer = keras.utils.deserialize_keras_object(config["model_optimizer"])
        loss_fn = keras.utils.deserialize_keras_object(config["loss_fn"])
        metrics = keras.utils.deserialize_keras_object(config["metric"])

        # Calls compile with the deserialized parameters
        self.compile(optimizer=optimizer, loss_fn=loss_fn,
                     metrics=metrics)


model = ModelWithCustomCompile()
model.compile(
    optimizer="SGD", loss_fn=small_square_sum_loss, metrics=["accuracy", mean_pred]
)

x = np.random.random((4, 8))
y = np.random.random((4,))

model.fit(x, y)

model.save("custom_compile_model.keras")
restored_model = keras.models.load_model("custom_compile_model.keras")

np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer)
np.testing.assert_equal(model.loss_fn, restored_model.loss_fn)
np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics)

"""
## Conclusion

The methods covered in this tutorial support a wide variety of use cases, including
saving and loading complex models with exotic assets and state elements. To recap:

- `save_own_variables` and `load_own_variables` determine how your states are saved
and loaded.
- `save_assets` and `load_assets` can be added to store and load any additional
information your model needs.
- `get_build_config` and `build_from_config` save and restore the model's built
states.
- `get_compile_config` and `compile_from_config` save and restore the model's
compiled states.
"""