Path: blob/master/site/zh-cn/lite/examples/on_device_training/overview.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
使用 TensorFlow Lite 进行设备端训练
在将 TensorFlow Lite 机器学习模型部署到设备或移动应用时,您可能希望根据设备或最终用户的输入改进或个性化该模型。通过使用设备端训练技术,您可以更新模型而无需让数据离开用户的设备,从而提高用户隐私,并且无需用户更新设备软件。
例如,您的移动应用中可能有一个识别时尚物品的模型,但您希望用户根据他们的兴趣随着时间的推移获得更好的识别性能。启用设备端训练后,对鞋子感兴趣的用户越频繁地使用您的应用,他们就可以更好地识别特定款式的鞋子或鞋类品牌。
本教程将向您展示如何构建一个 TensorFlow Lite 模型,该模型可以在已安装的 Android 应用中进行增量训练和改进。
注:设备端训练技术可以添加到现有的 TensorFlow Lite 实现中,前提是您的目标设备支持本地文件存储。
安装
本教程使用 Python 训练和转换 TensorFlow 模型,然后将其整合到 Android 应用中。从安装和导入以下软件包开始。
注:TensorFlow 2.7 及更高版本中提供了设备端训练 API。
对服装图像进行分类
此示例代码使用 Fashion MNIST 数据集训练神经网络模型,用于对服装图像进行分类。该数据集包含 60,000 张小(28 x 28 像素)灰度图像,其中包含 10 种不同类别的时尚配饰,包括连衣裙、衬衫和凉鞋。
<figure> <img src="https://tensorflow.org/images/fashion-mnist-sprite.png" alt="Fashion MNIST images"> <figcaption><b>Figure 1</b>: <a href="https://github.com/zalandoresearch/fashion-mnist">Fashion-MNIST samples</a> (by Zalando, MIT License).</figcaption> </figure>
您可以在 Keras 分类教程中更深入探索此数据集。
构建用于设备端训练的模型
TensorFlow Lite 模型通常只有一个公开的函数方法(或签名),允许您调用模型来运行推断。对于要在设备端训练和使用的模型,您必须能够执行几个单独的操作,包括训练、推断、保存和恢复模型的功能。您可以通过以下方式启用此功能:首先将 TensorFlow 模型扩展为具有多个函数,然后在将模型转换为 TensorFlow Lite 模型格式时将这些函数作为签名公开。
下面的代码示例展示了如何向 TensorFlow 模型添加以下函数:
train
函数用训练数据训练模型。infer
函数调用推断。save
函数将可训练权重保存到文件系统中。restore
函数从文件系统加载可训练权重。
上述代码中的 train
函数使用 GradientTape 类记录操作,以进行自动微分。有关如何使用此类的更多信息,请参阅梯度和自动微分简介。
在这里,您可以使用 Keras 模型的 Model.train_step
方法,而不是从头开始实现。只需注意,Model.train_step
返回的损失(和指标)是运行平均值,应该定期重置(通常是每个周期)。有关详细信息,请参阅自定义 Model.fit。
注:此模型生成的权重被序列化为 TensorFlow 1 格式的检查点文件。
准备数据
获取 Fashion MNIST 数据集以训练您的模型。
预处理数据集
此数据集中的像素值介于 0 和 255 之间,必须归一化为介于 0 和 1 之间的值才能由模型进行处理。将这些值除以 255 即可进行此调整。
通过执行独热编码将数据标签转换为分类值。
注:请确保以相同的方式对您的训练数据集和测试数据集进行预处理,以便您的测试准确评估模型的性能。
训练模型。
在转换和设置 TensorFlow Lite 模型之前,请使用经过预处理的数据集和 train
签名方法完成模型的初始训练。以下代码会运行 100 个周期的模型训练,一次处理 100 个图像批次,并在每 10 个周期之后显示损失值。由于本次训练运行要处理相当多的数据,因此可能需要几分钟才能完成。
注:您应该在将模型转换为 TensorFlow Lite 格式之前完成模型的初始训练,以便模型具有初始权重集,并且能够在开始收集数据并在设备上进行训练运行之前执行合理的推断。
将模型转换为 TensorFlow Lite 格式
扩展 TensorFlow 模型以启用设备端训练的附加功能并完成模型的初始训练后,可以将其转换为 TensorFlow Lite 格式。以下代码会将您的模型转换并保存为该格式,包括您在设备端与 TensorFlow Lite 模型一起使用的签名集:train, infer, save, restore
。
设置 TensorFlow Lite 签名
您在上一步中保存的 TensorFlow Lite 模型包含几个函数签名。您可以通过 tf.lite.Interpreter
类访问它们,并分别调用每个 restore
、train
、save
和 infer
签名。
比较原始模型和转换后的精简模型的输出:
在上面,您可以看到模型的行为不会因为转换为 TFLite 而改变。
在设备端重新训练模型
在将模型转换为 TensorFlow Lite 并将其用您的应用进行部署后,您可以使用新数据和模型的 train
签名方法在设备端重新训练模型。每次训练运行都会生成一组新的权重,您可以保存这些权重以供重用和进一步改进模型,如下一部分所示。
注:由于训练任务是资源密集型任务,您应该考虑在用户不主动与设备交互时执行这些任务,并将其作为后台进程。请考虑使用 WorkManager API 将模型重新训练安排为异步任务。
在 Android 上,您可以使用 Java 或 C++API 使用TensorFlow Lite 执行设备端训练。在 Java 中,使用 Interpreter
类加载模型并驱动模型训练任务。以下示例展示了如何使用 runSignature
方法运行训练过程:
您可以在模型个性化演示应用中看到 Android 应用内模型重新训练的完整代码示例。
运行几个周期的训练,以改进或个性化模型。在实践中,您将使用在设备端收集的数据来运行此附加训练。为简单起见,本例使用与上一个训练步骤相同的训练数据。
您可以从上面看到,设备端训练正好在预训练停止的地方开始。
保存训练后的权重
当您在设备端完成训练运行后,模型会更新它在内存中使用的权重集。使用您在 TensorFlow Lite 模型中创建的 save
签名方法,您可以将这些权重保存到检查点文件中以供以后重用,并改进您的模型。
在您的 Android 应用中,您可以将生成的权重作为检查点文件存储在为您的应用分配的内部存储空间中。
恢复训练后的权重
每当您从 TFLite 模型创建解释器时,解释器都会首先加载原始的模型权重。
因此,在完成一些训练并保存检查点文件后,您将需要运行 restore
签名方法来加载检查点。
一个良好的规则是“每当为模型创建解释器时,如果检查点存在,就加载该检查点”。如果需要将模型重置为基线行为,只需删除检查点并创建新的解释器。
检查点是通过使用 TFLite 进行训练和保存而生成的。在上面您可以看到,应用检查点会更新模型的行为。
注:根据模型中变量的数量和检查点文件的大小,从检查点加载保存的权重可能需要一些时间。
在您的 Android 应用中,您可以从之前存储的检查点文件中恢复序列化的、经过训练的权重。
注:当您的应用重新启动时,您应该在运行新的推断之前重新加载训练后的权重。
使用训练后的权重运行推断
从检查点文件加载以前保存的权重后,运行 infer
方法将这些权重与原始模型一起使用,以改进预测。加载保存的权重后,可以使用 infer
签名方法,如下图所示。
注:加载保存的权重并不是运行推断所必需的,但以该配置运行会使用最初训练的模型生成预测,而不会进行改进。
绘制预测的标签。
在您的 Android 应用中,在恢复训练的权重后,根据加载的数据运行推断。