Path: blob/master/site/zh-cn/federated/tutorials/sparse_federated_learning.ipynb
25118 views
Copyright 2021 The TensorFlow Federated Authors.
通过 federated_select
和稀疏聚合实现客户端高效大型模型联合学习
本教程将展示如何使用 TFF 训练非常大的模型,其中每个客户端设备只下载和更新模型的一小部分,我们将使用 tff.federated_select
和稀疏聚合。虽然本教程基本已涵盖了所有知识背景,但您也可以参阅 tff.federated_select
教程和自定义 FL 算法教程,其中很好地介绍了本文中使用的一些技术。
具体而言,我们在本教程中将考虑采用逻辑回归处理多标记分类,基于词袋特征表示来预测哪些“标记”与文本字符串相关联。重要的是,通信和客户端计算成本会由一个固定常量 (MAX_TOKENS_SELECTED_PER_CLIENT
) 控制,并且成本不会随整体词汇表大小(在实际环境中可能会非常大)而提高。
每个客户端都将 federated_select
(联合选择)模型权重的行,最多达到此数量的唯一词例。这是客户端本地模型大小以及所执行服务器 -> 客户端 (federated_select
) 和客户端 -> 服务器 (federated_aggregate
) 通信量的上限。
即使您将此项设置为低至 1 的值(确保不会每个客户端的所有词例都被选择)或设置为较大值,尽管可能会影响模型收敛,但本教程仍然应当可以正确运行。
我们还为各种类型定义一些常量。对于此 colab,词例是解析数据集后特定单词的整数标识符。
设置问题:数据集和模型
我们构造了一个小型数据集,以在本教程中进行简单的实验。但是,数据集的格式兼容联合 StackOverflow,预处理和模型架构取自 Adaptive Federated Optimization 的 StackOverflow 标记预测问题。
数据集解析和预处理
小型数据集
我们构造了一个小型数据集,其中具有包含 12 个单词的全局词汇表以及 3 个客户端。这个小型示例对于测试边缘案例(例如,我们采用两个不同词例数目小于 MAX_TOKENS_SELECTED_PER_CLIENT = 6
的客户端,以及一个大于该值的客户端)和开发代码而言非常实用。
然而,这种方式的实际用例将为包含数千万或更多条词汇的全局词汇表,每个客户端上都可能会出现数千个不同的词例。因为数据的格式是相同的,扩展到更现实的测试平台问题(例如 tff.simulation.datasets.stackoverflow.load_data()
数据集)应非常简单。
首先,我们将定义我们的单词和标记词汇表。
现在,我们使用小型本地数据集创建 3 个客户端。如果您在 colab 中运行本教程,使用“在标签页中镜像单元”功能来固定此单元及其输出以解释/检查下面开发的函数的输出可能会非常实用。
为原始数量的输入特征(词例/单词)和标签(发布标记)定义常量。我们的实际输入/输出空间大 NUM_OOV_BUCKETS = 1
,因为我们添加了一个 OOV 词例/标记。
创建数据集的批处理版本以及各个单独批次,这在我们随后测试代码时将非常实用。
定义具有稀疏输入的模型
我们对每个标记使用一个简单的独立逻辑回归模型。
让我们确保它有效,首先进行预测:
以及一些简单的集中训练:
联合计算的构建块
我们将实现一个简单版本的联合平均算法,关键区别在于每个设备均仅下载模型的相关子集,并且仅更新该子集。
我们使用 M
作为 MAX_TOKENS_SELECTED_PER_CLIENT
的简化表示形式。概括来讲,一轮训练会涉及以下步骤:
每个参与的客户端都会扫描其本地数据集,解析输入字符串并将它们映射到正确的词例(整数索引)。这需要访问全局(大型)字典(使用特征哈希技术有可能避免这种情况)。然后,我们稀疏计算每个词例出现的次数。如果设备上出现
U
个唯一词例,我们选择num_actual_tokens = min(U, M)
个最频繁的词例进行训练。客户端使用
federated_select
从服务器检索num_actual_tokens
所选词例的模型系数。每个模型切片均为形状为(TAG_VOCAB_SIZE, )
的张量,因此传输到客户端的数据总量最大为TAG_VOCAB_SIZE * M
(请参见下文注释)。客户端构造映射
global_token -> local_token
,其中本地词例(整数索引)是全局词例在所选词例列表中的索引。客户端使用全局模型的“小”版本,最多只有
M
个词例的系数,范围为[0, num_actual_tokens)
。global -> local
映射用于从所选模型切片初始化此模型的密集参数。客户端使用 SGD 基于使用
global -> local
映射预处理的数据训练其本地模型。客户端使用
local -> global
映射将其本地模型的参数转换为IndexedSlices
更新以索引行。服务器使用稀疏和聚合来聚合这些更新。服务器接受上述聚合的(密集)结果,将其除以参与的客户端数量,并将生成的平均更新应用于全局模型。
在本部分中,我们将为这些步骤构造构建块,然后将它们组合在最终的 federated_computation
中,用于捕获一轮训练的完整逻辑。
注:上面的介绍隐藏了一个技术细节:
federated_select
和本地模型的构造都需要静态已知形状,因此我们不能使用动态的按客户端num_actual_tokens
大小。相反,我们使用静态值M
,在需要的地方添加填充。这不会影响算法的语义。
计算客户端词例并决定哪些模型切片进行 federated_select
每个设备都需要决定模型的哪些“切片”与其本地训练数据集相关。对于我们的问题,我们通过(稀疏!)计算有多少样本包含客户端训练数据集中的每个词例。
我们将选择与设备上出现频率最高的 MAX_TOKENS_SELECTED_PER_CLIENT
个词例相对应的模型参数。如果设备上出现的词例少于该数量,我们将填充列表以便可以使用 federated_select
。
请注意,诸如随机选择词例(也许基于其出现概率)等其他策略可能会更好。这将确保模型的所有切片(客户端具有数据)都有机会得到更新。
将全局词例映射到本地词例
上面的选择为我们提供了 [0, actual_num_tokens)
区间内的一组密集词例,我们将用于设备端模型。但是,我们读取的数据集具有的词例来自更大型的全局词汇表区间 [0, WORD_VOCAB_SIZE)
。
因此,我们需要将全局词例映射到其对应的本地词例。本地词例 ID 由上一步中计算的 selected_tokens
张量的索引简单给定。
在每个客户端上训练本地(子)模型
请注意,federated_select
将以与选择键值相同的顺序将所选切片作为 tf.data.Dataset
返回。因此,我们首先定义一个效用函数来接受此类数据集并将其转换为单个密集张量,可用作客户端模型的模型权重。
现在,我们已拥有定义将在每个客户端上运行的简单本地训练循环所需的全部组件。
聚合 IndexedSlices
我们将使用 tff.federated_aggregate
为 IndexedSlices
构造联合稀疏和。这个简单的实现具有一项约束,即 dense_shape
为提前静态已知。另请注意,从客户端 -> 服务器通信为稀疏的角度而言,此和仅为半稀疏,但服务器会在 accumulate
和 merge
中维护和的密集表示,并会输出这个密集表示。
构造最小型 federated_computation
作为测试
全部置于 federated_computation
中
现在,我们使用 TFF 将组件绑定到 tff.federated_computation
中。
我们将使用基于联合平均的基本服务器训练函数,以 1.0 的服务器学习率应用更新。重要的是我们将对模型应用更新(增量),而非简单地对客户端提供的模型求平均值,否则如果模型的给定切片在给定轮次未用于任何客户端训练,其系数可能会归零。
我们还需要几个 tff.tf_computation
组件:
现在,我们已准备好将各部分组合到一起!
让我们训练模型!
现在,我们已拥有训练函数,我们来试用。