Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/tutorials/distribute/save_and_load.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

使用分布策略保存和加载模型

概述

本教程演示了如何在训练期间或训练之后使用 tf.distribute.Strategy 以 SavedModel 格式保存和加载模型。有两种用于保存和加载 Keras 模型的 API:高级(tf.keras.Model.savetf.keras.models.load_model)和低级(tf.saved_model.savetf.saved_model.load)。

要全面了解 SavedModel 和序列化,请阅读已保存模型指南Keras 模型序列化指南。我们从一个简单的示例开始。

小心:TensorFlow 模型是代码,对于不受信任的代码,一定要小心。请参阅安全使用 TensorFlow 以了解详情。

导入依赖项:

import tensorflow_datasets as tfds import tensorflow as tf

使用 TensorFlow Datasets 和 tf.data 加载和准备数据,并使用 tf.distribute.MirroredStrategy 创建模型:

mirrored_strategy = tf.distribute.MirroredStrategy() def get_data(): datasets = tfds.load(name='mnist', as_supervised=True) mnist_train, mnist_test = datasets['train'], datasets['test'] BUFFER_SIZE = 10000 BATCH_SIZE_PER_REPLICA = 64 BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync def scale(image, label): image = tf.cast(image, tf.float32) image /= 255 return image, label train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE) eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE) return train_dataset, eval_dataset def get_model(): with mirrored_strategy.scope(): model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Flatten(), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10) ]) model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=tf.keras.optimizers.Adam(), metrics=[tf.metrics.SparseCategoricalAccuracy()]) return model

使用 tf.keras.Model.fit 训练模型:

model = get_model() train_dataset, eval_dataset = get_data() model.fit(train_dataset, epochs=2)

保存和加载模型

现在,您已经有一个简单的模型可供使用,让我们探索保存/加载 API。有两种可用的 API:

  • 高级 (Keras):Model.savetf.keras.models.load_model.keras zip 存档格式)

  • 低级:tf.saved_model.savetf.saved_model.load(TF SavedModel 格式)

Keras API

以下为使用 Keras API 保存和加载模型的示例:

keras_model_path = '/tmp/keras_save.keras' model.save(keras_model_path)

恢复无 tf.distribute.Strategy 的模型:

restored_keras_model = tf.keras.models.load_model(keras_model_path) restored_keras_model.fit(train_dataset, epochs=2)

恢复模型后,您可以继在它上面续训练,甚至不需要再次调用 Model.compile,因为它在保存之前已经编译。模型以 Keras zip 存档格式保存,由 .keras 扩展名标记。有关详情,请参阅 Keras 保存指南

现在,恢复模型并使用 tf.distribute.Strategy 对其进行训练:

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0') with another_strategy.scope(): restored_keras_model_ds = tf.keras.models.load_model(keras_model_path) restored_keras_model_ds.fit(train_dataset, epochs=2)

正如 Model.fit 输出所示,tf.distribute.Strategy 可以按预期进行加载。此处使用的策略不必与保存前所用策略相同。

tf.saved_model API

使用较低级别的 API 保存模型类似于 Keras API:

model = get_model() # get a fresh model saved_model_path = '/tmp/tf_save' tf.saved_model.save(model, saved_model_path)

可以使用 tf.saved_model.load 进行加载。但是,由于它是一个较低级别的 API(因此用例范围更广泛),不会返回 Keras 模型。相反,它会返回一个对象,其中包含可用于进行推断的函数。例如:

DEFAULT_FUNCTION_KEY = 'serving_default' loaded = tf.saved_model.load(saved_model_path) inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

加载的对象可能包含多个函数,每个函数与一个键关联。"serving_default" 键是使用已保存的 Keras 模型的推断函数的默认键。要使用此函数进行推断,请运行以下代码:

predict_dataset = eval_dataset.map(lambda image, label: image) for batch in predict_dataset.take(1): print(inference_func(batch))

您还可以采用分布式方式加载和进行推断:

another_strategy = tf.distribute.MirroredStrategy() with another_strategy.scope(): loaded = tf.saved_model.load(saved_model_path) inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY] dist_predict_dataset = another_strategy.experimental_distribute_dataset( predict_dataset) # Calling the function in a distributed manner for batch in dist_predict_dataset: result = another_strategy.run(inference_func, args=(batch,)) print(result) break

调用已恢复的函数只是基于已保存模型的前向传递 (tf.keras.Model.predict)。如果您想继续训练加载的函数,或者将加载的函数嵌入到更大的模型中,应如何操作?通常的做法是将此加载对象封装到 Keras 层以实现此目的。幸运的是,TF Hub 为此提供了 hub.KerasLayer,如下所示:

import tensorflow_hub as hub def build_model(loaded): x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x') # Wrap what's loaded to a KerasLayer keras_layer = hub.KerasLayer(loaded, trainable=True)(x) model = tf.keras.Model(x, keras_layer) return model another_strategy = tf.distribute.MirroredStrategy() with another_strategy.scope(): loaded = tf.saved_model.load(saved_model_path) model = build_model(loaded) model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=tf.keras.optimizers.Adam(), metrics=[tf.metrics.SparseCategoricalAccuracy()]) model.fit(train_dataset, epochs=2)

在上面的示例中,TensorFlow Hub 的 hub.KerasLayer 可将从 tf.saved_model.load 加载回的结果封装到可用于构建其他模型的 Keras 层。这对于迁移学习非常实用。

我应使用哪种 API?

对于保存,如果您使用的是 Keras 模型,请使用 Keras Model.save API,除非您需要低级 API 允许的额外控制。如果您保存的不是 Keras 模型,那么您只能选择使用较低级的 API tf.saved_model.save

对于加载,您的 API 选择取决于您要从加载 API 中获得什么。如果您无法(或不想)获取 Keras 模型,请使用 tf.saved_model.load。否则,请使用 tf.keras.models.load_model。请注意,只有保存 Keras 模型后,才能恢复 Keras 模型。

可以搭配使用 API。您可以使用 model.save 保存 Keras 模型,并使用低级 API tf.saved_model.load 加载非 Keras 模型。

model = get_model() # Saving the model using Keras `Model.save` model.save(saved_model_path) another_strategy = tf.distribute.MirroredStrategy() # Loading the model using the lower-level API with another_strategy.scope(): loaded = tf.saved_model.load(saved_model_path)

从本地设备保存/加载

在远程设备上训练的过程中从本地 I/O 设备保存和加载时(例如,使用 Cloud TPU 时),必须使用 tf.saved_model.SaveOptionstf.saved_model.LoadOptions 中的选项 experimental_io_device 将 I/O 设备设置为 localhost。例如:

model = get_model() # Saving the model to a path on localhost. saved_model_path = '/tmp/tf_save' save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost') model.save(saved_model_path, options=save_options) # Loading the model from a path on localhost. another_strategy = tf.distribute.MirroredStrategy() with another_strategy.scope(): load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost') loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

警告

一种特殊情况是当您以某种方式创建 Keras 模型,然后在训练之前保存它们。例如:

class SubclassedModel(tf.keras.Model): """Example model defined by subclassing `tf.keras.Model`.""" output_name = 'output_layer' def __init__(self): super(SubclassedModel, self).__init__() self._dense_layer = tf.keras.layers.Dense( 5, dtype=tf.dtypes.float32, name=self.output_name) def call(self, inputs): return self._dense_layer(inputs) my_model = SubclassedModel() try: my_model.save(saved_model_path) except ValueError as e: print(f'{type(e).__name__}: ', *e.args)

SavedModel 保存跟踪 tf.function 时生成的 tf.types.experimental.ConcreteFunction 对象(请查看计算图和 tf.function 简介指南中的*函数何时执行跟踪?*了解更多信息)。如果您收到像这样的 ValueError,那是因为 Model.save 无法找到或创建跟踪的 ConcreteFunction

**小心:**您不应在一个 ConcreteFunction 都没有的情况下保存模型,因为如果这样做,低级 API 将生成一个没有 ConcreteFunction 签名的 SavedModel(详细了解 SavedModel 格式)。例如:

tf.saved_model.save(my_model, saved_model_path) x = tf.saved_model.load(saved_model_path) x.signatures

一般而言,模型的前向传递(call 方法)会在第一次调用模型时被自动跟踪,通常是通过 Keras Model.fit 方法。如果您设置了输入形状,例如通过将第一层设为 tf.keras.layers.InputLayer 或其他层类型,并将 input_shape 关键字参数传递给它,Keras 序贯函数式 API 也可以生成 ConcreteFunction

要验证您的模型是否有任何跟踪的 ConcreteFunction,请检查 Model.save_spec 是否为 None

print(my_model.save_spec() is None)

我们使用 tf.keras.Model.fit 来训练模型,可以注意到,save_spec 被定义并且模型保存将生效:

BATCH_SIZE_PER_REPLICA = 4 BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync dataset_size = 100 dataset = tf.data.Dataset.from_tensors( (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32)) ).repeat(dataset_size).batch(BATCH_SIZE) my_model.compile(optimizer='adam', loss='mean_squared_error') my_model.fit(dataset, epochs=2) print(my_model.save_spec() is None) my_model.save(saved_model_path)