Copyright 2019 The TensorFlow Authors.
Save and load a model using a distribution strategy
Overview
This tutorial demonstrates how you can save and load models in a SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (tf.keras.Model.save and tf.keras.models.load_model) and low-level (tf.saved_model.save and tf.saved_model.load).
To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example.
Caution: TensorFlow models are code and it is important to be careful with untrusted code. Learn more in Using TensorFlow securely.
Import dependencies:
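A minimal sketch of the imports used throughout this tutorial:

```python
import tensorflow_datasets as tfds
import tensorflow as tf
```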
Load and prepare the data with TensorFlow Datasets and tf.data, and create the model using tf.distribute.MirroredStrategy:
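The sketch below uses MNIST from TensorFlow Datasets as an example; the helpers get_data and get_model are illustrative names, not part of any API:

```python
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000
  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    # Normalize pixel values to [0, 1].
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  # Variables created inside the strategy scope are mirrored across replicas.
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return model
```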
Train the model with tf.keras.Model.fit:
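Continuing the sketch:

```python
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
```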
Save and load the model
Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:
- High-level (Keras): Model.save and tf.keras.models.load_model (.keras zip archive format)
- Low-level: tf.saved_model.save and tf.saved_model.load (TF SavedModel format)
The Keras API
Here is an example of saving and loading a model with the Keras API:
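For example (the path is a placeholder; the .keras extension selects the Keras zip archive format):

```python
keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)
```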
Restore the model without tf.distribute.Strategy:
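For example, loading outside of any strategy scope and continuing training:

```python
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
```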
After restoring the model, you can continue training on it, even without needing to call Model.compile again, since it was already compiled before saving. The model is saved in the Keras zip archive format, marked by the .keras extension. For more information, please refer to the guide on Keras saving.
Now, restore the model and train it using a tf.distribute.Strategy:
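A sketch using tf.distribute.OneDeviceStrategy purely for illustration:

```python
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
```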
As the Model.fit output shows, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.
The tf.saved_model API
Saving the model with the lower-level API is similar to the Keras API:
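For example (again with a placeholder path):

```python
model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
```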
Loading can be done with tf.saved_model.load. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:
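A sketch of loading the model and picking out the default inference function:

```python
DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
```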
The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function with a saved Keras model. To do inference with this function:
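For instance, dropping the labels and running a single batch through the function:

```python
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
```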
You can also load and do inference in a distributed manner:
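A sketch, loading under a fresh strategy scope and running the function with Strategy.run:

```python
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(predict_dataset)

  # Run one batch of the distributed dataset on all replicas.
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
```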
Calling the restored function is just a forward pass on the saved model, like tf.keras.Model.predict. What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common practice is to wrap the loaded object into a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:
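A sketch, assuming tensorflow_hub is installed; build_model is an illustrative helper:

```python
import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap the loaded object in a Keras layer so it can be composed.
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
```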
In the above example, TensorFlow Hub's hub.KerasLayer wraps the result loaded back from tf.saved_model.load into a Keras layer that is used to build another model. This is very useful for transfer learning.
Which API should I use?
For saving, if you are working with a Keras model, use the Keras Model.save API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, tf.saved_model.save, is your only choice.
For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load. Otherwise, use tf.keras.models.load_model. Note that you can get a Keras model back only if you saved a Keras model.
It is possible to mix and match the APIs. You can save a Keras model in the TF SavedModel format with Model.save, and then load it as a non-Keras object with the low-level API, tf.saved_model.load.
Saving/Loading from a local device
When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option experimental_io_device in tf.saved_model.SaveOptions and tf.saved_model.LoadOptions to set the I/O device to localhost. For example:
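A sketch, assuming TF 2.x Keras behavior where saving to a path without a .keras extension produces the TF SavedModel format (the options argument applies to that format):

```python
model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
```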
Caveats
One special case is when you create Keras models in certain ways, and then save them before training. For example:
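A sketch of a subclassed model that is saved before it has ever been called; the class and paths are illustrative:

```python
class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super().__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  # Saving before the model has ever been called fails, because no
  # input shape is known and `call` has never been traced.
  my_model.save('/tmp/my_model.keras')
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)
```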
A SavedModel saves the tf.types.experimental.ConcreteFunction objects generated when you trace a tf.function (check When is a Function tracing? in the Introduction to graphs and tf.function guide to learn more). If you get a ValueError like this, it's because Model.save was not able to find or create a traced ConcreteFunction.
Caution: You shouldn't save a model without at least one ConcreteFunction, since the low-level API will otherwise generate a SavedModel with no ConcreteFunction signatures (learn more about the SavedModel format). For example:
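A sketch under the TF 2.x Keras behavior, with a placeholder path; saving is expected to succeed but produce no signatures:

```python
# Hypothetical path for this demo.
unsigned_path = '/tmp/my_model_unsigned'

# Saving succeeds, but the resulting SavedModel has no ConcreteFunction
# signatures, because `my_model.call` was never traced.
tf.saved_model.save(my_model, unsigned_path)

loaded = tf.saved_model.load(unsigned_path)
print('Signatures:', list(loaded.signatures.keys()))  # expected: []
```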
Usually the model's forward pass—the call method—will be traced automatically when the model is called for the first time, often via the Keras Model.fit method. A ConcreteFunction can also be generated by the Keras Sequential and Functional APIs, if you set the input shape, for example, by making the first layer either a tf.keras.layers.InputLayer or another layer type, and passing it the input_shape keyword argument.
To verify if your model has any traced ConcreteFunctions, check if Model.save_spec is None:
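For example (save_spec is available on tf.keras.Model in recent TF 2.x releases; it is a method, so call it and compare to None):

```python
print(my_model.save_spec() is None)  # True: `call` has not been traced yet
```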
Let's use tf.keras.Model.fit to train the model, and notice that the save_spec gets defined and model saving will work:
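A sketch with a small synthetic dataset, just enough to trigger tracing:

```python
BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

# A small synthetic dataset of (features, labels) pairs.
dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)  # now False: `call` has been traced
my_model.save('/tmp/my_model.keras')  # saving now succeeds
```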