Path: blob/master/site/zh-cn/tutorials/distribute/save_and_load.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
使用分布策略保存和加载模型
概述
本教程演示了如何在训练期间或训练之后使用 tf.distribute.Strategy
以 SavedModel 格式保存和加载模型。有两种用于保存和加载 Keras 模型的 API:高级(tf.keras.Model.save
和 tf.keras.models.load_model
)和低级(tf.saved_model.save
和 tf.saved_model.load
)。
要全面了解 SavedModel 和序列化,请阅读已保存模型指南和 Keras 模型序列化指南。我们从一个简单的示例开始。
小心:TensorFlow 模型是代码,对于不受信任的代码,一定要小心。请参阅安全使用 TensorFlow 以了解详情。
导入依赖项:
使用 TensorFlow Datasets 和 tf.data
加载和准备数据,并使用 tf.distribute.MirroredStrategy
创建模型:
使用 tf.keras.Model.fit
训练模型:
保存和加载模型
现在,您已经有一个简单的模型可供使用,让我们探索保存/加载 API。有两种可用的 API:
高级 (Keras):
Model.save
和tf.keras.models.load_model
(.keras
zip 存档格式)低级:
tf.saved_model.save
和tf.saved_model.load
(TF SavedModel 格式)
Keras API
以下为使用 Keras API 保存和加载模型的示例:
恢复无 tf.distribute.Strategy
的模型:
恢复模型后,您可以继在它上面续训练,甚至不需要再次调用 Model.compile
,因为它在保存之前已经编译。模型以 Keras zip 存档格式保存,由 .keras
扩展名标记。有关详情,请参阅 Keras 保存指南。
现在,恢复模型并使用 tf.distribute.Strategy
对其进行训练:
正如 Model.fit
输出所示,tf.distribute.Strategy
可以按预期进行加载。此处使用的策略不必与保存前所用策略相同。
tf.saved_model
API
使用较低级别的 API 保存模型类似于 Keras API:
可以使用 tf.saved_model.load
进行加载。但是,由于它是一个较低级别的 API(因此用例范围更广泛),不会返回 Keras 模型。相反,它会返回一个对象,其中包含可用于进行推断的函数。例如:
加载的对象可能包含多个函数,每个函数与一个键关联。"serving_default"
键是使用已保存的 Keras 模型的推断函数的默认键。要使用此函数进行推断,请运行以下代码:
您还可以采用分布式方式加载和进行推断:
调用已恢复的函数只是基于已保存模型的前向传递 (tf.keras.Model.predict
)。如果您想继续训练加载的函数,或者将加载的函数嵌入到更大的模型中,应如何操作?通常的做法是将此加载对象封装到 Keras 层以实现此目的。幸运的是,TF Hub 为此提供了 hub.KerasLayer
,如下所示:
在上面的示例中,TensorFlow Hub 的 hub.KerasLayer
可将从 tf.saved_model.load
加载回的结果封装到可用于构建其他模型的 Keras 层。这对于迁移学习非常实用。
我应使用哪种 API?
对于保存,如果您使用的是 Keras 模型,请使用 Keras Model.save
API,除非您需要低级 API 允许的额外控制。如果您保存的不是 Keras 模型,那么您只能选择使用较低级的 API tf.saved_model.save
。
对于加载,您的 API 选择取决于您要从加载 API 中获得什么。如果您无法(或不想)获取 Keras 模型,请使用 tf.saved_model.load
。否则,请使用 tf.keras.models.load_model
。请注意,只有保存 Keras 模型后,才能恢复 Keras 模型。
可以搭配使用 API。您可以使用 model.save
保存 Keras 模型,并使用低级 API tf.saved_model.load
加载非 Keras 模型。
从本地设备保存/加载
在远程设备上训练的过程中从本地 I/O 设备保存和加载时(例如,使用 Cloud TPU 时),必须使用 tf.saved_model.SaveOptions
和 tf.saved_model.LoadOptions
中的选项 experimental_io_device
将 I/O 设备设置为 localhost
。例如:
警告
一种特殊情况是当您以某种方式创建 Keras 模型,然后在训练之前保存它们。例如:
SavedModel 保存跟踪 tf.function
时生成的 tf.types.experimental.ConcreteFunction
对象(请查看计算图和 tf.function 简介指南中的*函数何时执行跟踪?*了解更多信息)。如果您收到像这样的 ValueError
,那是因为 Model.save
无法找到或创建跟踪的 ConcreteFunction
。
**小心:**您不应在一个 ConcreteFunction
都没有的情况下保存模型,因为如果这样做,低级 API 将生成一个没有 ConcreteFunction
签名的 SavedModel(详细了解 SavedModel 格式)。例如:
我们使用 tf.keras.Model.fit
来训练模型,可以注意到,save_spec
被定义并且模型保存将生效: