GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/federated/tutorials/sparse_federated_learning.ipynb
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Client-efficient large-model federated learning via federated_select and sparse aggregation

This tutorial shows how to use TFF to train a very large model where each client device only downloads and updates a small part of the model, using tff.federated_select and sparse aggregation. While this tutorial is reasonably self-contained, the tff.federated_select tutorial and the custom FL algorithms tutorial provide good introductions to some of the techniques used here.

Concretely, in this tutorial we consider logistic regression for multi-label classification, predicting which "tags" are associated with a text string based on a bag-of-words feature representation. Importantly, communication and client-side computation costs are controlled by a fixed constant (MAX_TOKENS_SELECTED_PER_CLIENT), and do not scale with the overall vocabulary size, which could be extremely large in practical settings.

#@test {"skip": true}
!pip install --quiet --upgrade tensorflow-federated
import collections
from collections.abc import Callable
import itertools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Each client will federated_select the rows of the model weights for at most this many unique tokens. This upper-bounds the size of the client's local model and the amount of server -> client (federated_select) and client -> server (federated_aggregate) communication performed.

This tutorial should still run correctly even if you set this as low as 1 (ensuring not all tokens from every client are selected) or to a large value, although model convergence may be affected.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

We also define a few constants for various types. For this colab, a token is an integer identifier for a particular word after parsing the dataset.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurrences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE.
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

Setting up the problem: datasets and model

We construct a small toy dataset for easy experimentation in this tutorial. However, the format of the dataset is compatible with the federated StackOverflow dataset, and the pre-processing and model architecture are adopted from the StackOverflow tag prediction problem of Adaptive Federated Optimization.

Dataset parsing and pre-processing

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])


def build_to_ids_fn(word_vocab: list[str],
                    tag_vocab: list[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

A tiny toy dataset

We construct a tiny toy dataset with a global vocabulary of 12 words and 3 clients. This small example is useful for testing edge cases (for example, we have two clients with fewer than MAX_TOKENS_SELECTED_PER_CLIENT = 6 distinct tokens, and one with more) and for developing the code.

However, the real-world use cases of this approach would be global vocabularies of tens of millions of words or more, with perhaps thousands of distinct tokens appearing on each client. Because the format of the data is the same, the extension to more realistic testbed problems, e.g. the tff.simulation.datasets.stackoverflow.load_data() dataset, should be straightforward; a hedged sketch of such an extension follows below.
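As an illustration only (not run in this colab), the sketch below shows how one might plug in the real federated StackOverflow data using the helpers in tff.simulation.datasets.stackoverflow. The vocabulary sizes and the build_stackoverflow_preprocess_fn wrapper are illustrative choices, and the exact helper signatures may vary by TFF version.

# Hedged sketch: reuse build_preprocess_fn with the real StackOverflow data.
def build_stackoverflow_preprocess_fn(word_vocab_size=10000, tag_vocab_size=500):
  # Word/tag counts come sorted by frequency; take the most frequent ones.
  word_counts = tff.simulation.datasets.stackoverflow.load_word_counts(
      vocab_size=word_vocab_size)
  tag_counts = tff.simulation.datasets.stackoverflow.load_tag_counts()
  word_vocab = list(word_counts.keys())[:word_vocab_size]
  tag_vocab = list(tag_counts.keys())[:tag_vocab_size]
  return build_preprocess_fn(word_vocab, tag_vocab)

# Usage sketch (downloads a large dataset, so it is left commented out):
# train_data, _, _ = tff.simulation.datasets.stackoverflow.load_data()
# stackoverflow_preprocess_fn = build_stackoverflow_preprocess_fn()
# client_dataset = stackoverflow_preprocess_fn(
#     train_data.create_tf_dataset_for_client(train_data.client_ids[0]))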

First, we define our word and tag vocabularies.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

Now, we create 3 clients with small local datasets. If you are running this tutorial in colab, it may be useful to use the "mirror cell in tab" feature to pin this cell and its output, in order to interpret/check the output of the functions developed below.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tags
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Make sure the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

Define constants for the raw numbers of input features (tokens/words) and labels (post tags). Our actual input/output spaces are NUM_OOV_BUCKETS = 1 larger because we add an OOV token/tag.

NUM_WORDS = len(WORD_VOCAB)
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

Create batched versions of the datasets, as well as individual batches, which will be useful in testing code as we go.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

Define a model with sparse inputs

We use a simple independent logistic regression model for each tag.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

Let's make sure it works, first by making some predictions:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

And some simple centralized training:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

Building blocks for the federated computation

We will implement a simple version of the Federated Averaging algorithm, with the key difference that each device only downloads a relevant subset of the model, and only contributes updates to that subset.

We use M as shorthand for MAX_TOKENS_SELECTED_PER_CLIENT. At a high level, one round of training involves these steps:

  1. Each participating client scans over its local dataset, parsing the input strings and mapping them to the correct tokens (integer indexes). This requires access to the global (large) dictionary (this could potentially be avoided using feature hashing techniques). We then sparsely count how many times each token occurs. If U unique tokens occur on device, we choose the num_actual_tokens = min(U, M) most frequent tokens to train.

  2. The clients use federated_select to retrieve the model coefficients for the num_actual_tokens selected tokens from the server. Each model slice is a tensor of shape (TAG_VOCAB_SIZE, ), so the total data transmitted to the client is at most of size TAG_VOCAB_SIZE * M (see note below).

  3. The clients construct a mapping global_token -> local_token where the local token (integer index) is the index of the global token in the list of selected tokens.

  4. The clients use a "small" version of the global model that only has coefficients for at most M tokens, from the range [0, num_actual_tokens). The global -> local mapping is used to initialize the dense parameters of this model from the selected model slices.

  5. The clients train their local model using SGD on data preprocessed with the global -> local mapping.

  6. The clients turn the parameters of their local model into IndexedSlices updates using the local -> global mapping to index the rows. The server aggregates these updates using a sparse sum aggregation.

  7. The server takes the (dense) result of the above aggregation, divides it by the number of participating clients, and applies the resulting average update to the global model.

In this section we construct the building blocks for these steps, which will then be combined in a final federated_computation that captures the full logic of one training round.

Note: The above description hides one technical detail: Both federated_select and the construction of the local model require statically known shapes, so we cannot use the dynamic per-client num_actual_tokens size. Instead, we use the static value M, adding padding where needed. This does not change the semantics of the algorithm.
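To make steps 6 and 7 concrete before building the TFF version, here is a minimal plain-TensorFlow sketch (no TFF, illustrative shapes and values only) of how a client's local delta is re-indexed into the global vocabulary with tf.IndexedSlices and how the server applies the densified average:

# Illustrative-only sketch of steps 6 and 7, with made-up sizes and values.
GLOBAL_ROWS, TAGS = 13, 4
server_weights_sketch = tf.zeros((GLOBAL_ROWS, TAGS))

# Suppose a client selected global tokens [2, 5, 7] and produced a local
# delta with one row per selected token (the local -> global mapping is just
# the position in `selected_global_tokens`).
selected_global_tokens = tf.constant([2, 5, 7], dtype=tf.int64)
local_delta = tf.fill((3, TAGS), 0.1)

sparse_update = tf.IndexedSlices(
    values=local_delta,
    indices=selected_global_tokens,
    dense_shape=(GLOBAL_ROWS, TAGS))

# The server sums such updates densely, divides by the number of
# participating clients (1 here), and applies the average.
num_participating_clients = 1.0
server_weights_sketch = (
    server_weights_sketch
    + tf.convert_to_tensor(sparse_update) / num_participating_clients)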

Counting client tokens and deciding which model slices to federated_select

Each device needs to decide which "slices" of the model are relevant to its local training dataset. For our problem, we do this by (sparsely!) counting how many examples contain each token in the client's training dataset.

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)

tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurrences, since we still featurize
# multiple occurrences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

We will select the model parameters corresponding to the MAX_TOKENS_SELECTED_PER_CLIENT most frequently occurring tokens on device. If fewer than this many tokens occur on device, we pad the list so that we can use federated_select.

Note that other strategies might be better, for example, randomly selecting tokens (perhaps based on their occurrence probability). This would ensure that all slices of the model (for which the client has data) have some chance of being updated. A hedged sketch of one such strategy follows below.
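As an illustration only (not used elsewhere in this tutorial), the sketch below samples up to max_tokens_per_client tokens without replacement, with probability proportional to their counts, via the Gumbel-top-k trick. In principle it could replace the top-k-by-count selection inside the keys_for_client function defined in the next cell.

# Hedged sketch of an alternative selection strategy.
def sample_tokens_proportional_to_counts(tokens, counts, max_tokens_per_client):
  """Samples tokens without replacement, proportional to their counts."""
  logits = tf.math.log(tf.cast(counts, tf.float32))
  # Adding Gumbel(0, 1) noise to the logits and taking the top-k is
  # equivalent to sampling k items without replacement with probability
  # proportional to exp(logits), i.e. to the counts.
  gumbel_noise = -tf.math.log(-tf.math.log(
      tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)))
  k = tf.minimum(max_tokens_per_client, tf.shape(tokens)[0])
  _, sampled_positions = tf.math.top_k(logits + gumbel_noise, k=k)
  return tf.gather(tokens, sampled_positions)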

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occurring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

Map the global tokens to local tokens

The above selection gives us a dense set of tokens in the range [0, actual_num_tokens) which we will use for the on-device model. However, the dataset we read has tokens from the much larger global vocabulary range [0, WORD_VOCAB_SIZE).

Thus, we need to map the global tokens to their corresponding local tokens. The local token IDs are simply given by the indexes into the selected_tokens tensor computed in the previous step.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id),
        tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        # Each client has at most MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
        dense_shape=[sparse_tokens.dense_shape[0],
                     MAX_TOKENS_SELECTED_PER_CLIENT])
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)

batch = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

Train the local (sub)model on each client

Note that federated_select will return the selected slices as a tf.data.Dataset in the same order as the selection keys. So, we first define a utility function to take such a Dataset and convert it to a single dense tensor which can be used as the model weights of the client model.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

We now have all the components we need to define a simple local training loop which will run on each client.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)

    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0],
                                 size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceding cell to avoid "tf.function-decorated function tried to
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

Aggregate IndexedSlices

We use tff.federated_aggregate to construct a federated sparse sum of IndexedSlices. This simple implementation has the constraint that the dense_shape is known statically in advance. Note also that this sum is only semi-sparse, in the sense that the client -> server communication is sparse, but the server maintains a dense representation of the sum in accumulate and merge, and outputs this dense representation.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sums IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSlices.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))

  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

Construct a minimal federated_computation as a test

dense_shape = (6, 2)

indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))


@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)


print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

Putting it all together in a federated_computation

We now use TFF to bind the components together into a tff.federated_computation.

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

We use a basic server training function based on Federated Averaging, applying the update with a server learning rate of 1.0. It is important that we apply an update (delta) to the model, rather than simply averaging client-supplied models, because otherwise if a given slice of the model wasn't trained on by any client on a given round, its coefficients could be zeroed out.

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

We need a couple more tff.tf_computation components:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is smaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # We would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)


@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

We're now ready to put all the pieces together!

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

Let's train the model!

Now that we have our training function, let's try it out.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
    ])


def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in
                           (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80