Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/federated/tutorials/loading_remote_data.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

在 TFF 中加载远程数据


:本 Colab 已通过验证,可与最新发布版本tensorflow_federated pip 软件包一起使用,但 Tensorflow Federated 项目仍处于预发布开发阶段,可能无法在 main 上运行。

在联合学习的实际应用中,原始训练数据通常分布在许多设备或数据孤岛中 – 需要特殊的预处理和加载才能使用。

本教程介绍了如何使用 TFF 的 DataBackendDataExecutor 接口加载存储在这些远程位置的样本,并借助它们来使用联合学习训练模型。我们将通过使用存储在本地的训练数据集来演示数据加载 API 的用法,并模拟样本的采样,就好像数据集在不同的远程客户端上进行了分区一样。当您根据您的用例调整本教程时,您只需将该数据集与您自己的分布式数据交换即可。

如果您不熟悉联合学习或 TFF,请考虑阅读图像分类的联合学习作为入门读物。

准备工作

在开始之前,请运行以下命令来确保您的环境已正确设置。有关详情,请参阅安装指南。

#@title Set up open-source environment #@test {"skip": true} !pip install --quiet --upgrade tensorflow-federated !pip install --quiet --upgrade nest-asyncio import nest_asyncio nest_asyncio.apply()
#@title Import packages import collections import random from typing import Any, List, Optional, Sequence import numpy as np import tensorflow as tf import tensorflow_federated as tff np.random.seed(0)

准备输入数据

让我们首先从内置仓库中加载 TFF 的 EMNIST 数据集的联合版本:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

构造一个预处理函数来转换 EMNIST 数据集中的原始样本。

NUM_EPOCHS = 5 SHUFFLE_BUFFER = 100 def preprocess(dataset): def map_fn(element): # Rename the features from `pixels` and `label`, to `x` and `y` for use with # Keras. return collections.OrderedDict( # Transform each `28x28` image into a `784`-element array. x=tf.reshape(element['pixels'], [-1, 784]), y=tf.reshape(element['label'], [-1, 1])) # Shuffle the individual examples and `repeat` over several epochs. return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)

我们来验证一下这是否有效:

# The local dataset corresponding to a single client as tf.data.Dataset. example_dataset = emnist_train.create_tf_dataset_for_client( emnist_train.client_ids[0]) preprocessed_example_dataset = preprocess(example_dataset) print(preprocessed_example_dataset)
<MapDataset element_spec=OrderedDict([('x', TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))])>

接下来,我们将构造一个 DataBackend 的实现,它将加载和预处理来自 EMNIST 数据集中客户端的本地样本,这对于在联合学习期间获取可训练样本至关重要。

定义如何获取客户端数据

我们需要一个 DataBackend 的实例来指示 TFF 工作进程如何加载和转换本地数据。

TFF 工作进程是在边缘机器上运行并为单个或多个逻辑客户端执行工作的进程。在此示例中,我们将用于训练的 EMNIST 数据集已经被逻辑客户端分区,并且所有工作进程都将在同一个本地环境中运行。因此,我们的 DataBackend 可以引用任何客户端对应的数据。但在非实验性设置中,TFF 工作进程将分布在各个远程计算机上,每台计算机都映射到一组不同的客户端,您需要确保 DataBackend 可以根据其本地上下文正确解析数据引用。

# A `DataBackend` is a programmatic construct that resolves symbolic references, # represented as application-specific URIs, to materialized examples that # TFF operations can process. class MyDataBackend(tff.framework.DataBackend): async def materialize(self, data, type_spec): # In this example, the URI contains the Id of a client. client_id = int(data.uri[-1]) # The client Id is used to retrieve the corresponding local data. client_dataset = emnist_train.create_tf_dataset_for_client( emnist_train.client_ids[client_id]) # We process the client dataset before returning so its compatible with our # model definitions. return preprocess(client_dataset)

设置运行时环境

TFF 计算由 ExecutionContext 调用,为了在运行时理解 TFF 计算中定义的数据 URI,必须为 TFF 工作进程定义一个自定义上下文,其中包含指向我们刚刚创建的 DataBackend 的指针,以便可以正确解析 URI。

def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor: # A `DataBackend` object is wrapped by a `DataExecutor`, which queries the # backend when a TFF worker encounters an operation requires fetching local # data. return tff.framework.DataExecutor( tff.framework.EagerTFExecutor(device), data_backend=MyDataBackend()) # In a distributed setting, this needs to run in the TFF worker as a service # connecting to some port. The top-level controller feeding TFF computations # would then connect to this port. factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn) ctx = tff.framework.ExecutionContext(executor_fn=factory) tff.framework.set_default_context(ctx)

训练模型

现在,我们已经准备好使用联合学习来训练模型。我们来定义一个 Keras 模型:

def create_keras_model(): return tf.keras.models.Sequential([ tf.keras.layers.InputLayer(input_shape=(784,)), tf.keras.layers.Dense(10, kernel_initializer='zeros'), tf.keras.layers.Softmax(), ]) def model_fn(): keras_model = create_keras_model() return tff.learning.from_keras_model( keras_model, input_spec=preprocessed_example_dataset.element_spec, loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

我们可以通过调用辅助函数 tff.learning.algorithms.build_weighted_fed_avg 将模型的这个 TFF 包装的定义传递给联合平均算法,如下所示:

iterative_process = tff.learning.algorithms.build_weighted_fed_avg( model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02), server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)) state = iterative_process.initialize()

initialize 计算会返回联合平均过程的初始状态。

为了运行一轮训练,我们需要通过收集 URI 引用样本来构造数据样本,如下所示:

NUM_CLIENTS = 10 element_type = tff.types.StructWithPythonType( preprocessed_example_dataset.element_spec, container_type=collections.OrderedDict) dataset_type = tff.types.SequenceType(element_type) round_data_uris = [f'uri://{i}' for i in range(NUM_CLIENTS)] round_train_data = tff.framework.CreateDataDescriptor( arg_uris=round_data_uris, arg_type=dataset_type)

现在,我们可以进行一轮训练:

result = iterative_process.next(state, round_train_data) state = result.state metrics = result.metrics print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11234568), ('loss', 11.965633), ('num_examples', 4860), ('num_batches', 4860)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])

多轮训练

我们可以定义一个 FederatedDataSource 容器来选择客户端并组装输入以检索本地数据。这样一来,循环多轮训练变得十分方便,并且可以在多个训练作业中重复使用。

class MyFederatedDataSourceIterator(tff.program.FederatedDataSourceIterator): def __init__(self, client_ids: Sequence[str], federated_type: tff.FederatedType): self._client_ids = client_ids self._federated_type = federated_type @property def federated_type(self) -> tff.FederatedType: return self._federated_type def select(self, number_of_clients: Optional[int] = None) -> Any: client_ids_sample = random.sample(self._client_ids, number_of_clients) data_uris = [f'uri://{i}' for i in client_ids_sample] return tff.framework.CreateDataDescriptor( arg_uris=data_uris, arg_type=self._federated_type) class MyFederatedDataSource(tff.program.FederatedDataSource): def __init__(self, client_ids: Sequence[str], federated_type: tff.FederatedType): self._client_ids = client_ids self._federated_type = federated_type self._capabilities = [tff.program.Capability.RANDOM_UNIFORM] @property def federated_type(self) -> tff.FederatedType: return self._federated_type @property def capabilities(self) -> List[tff.program.Capability]: return self._capabilities def iterator(self) -> tff.program.FederatedDataSourceIterator: return MyFederatedDataSourceIterator(self._client_ids, self._federated_type) train_data_source = MyFederatedDataSource( client_ids=emnist_train.client_ids, federated_type=dataset_type) train_data_iterator = train_data_source.iterator()

现在,我们可以按如下方式运行联合学习训练循环:

NUM_ROUNDS = 10 for round_num in range(2, NUM_ROUNDS + 1): round_train_data = train_data_iterator.select(NUM_CLIENTS) result = iterative_process.next(state, round_train_data) state = result.state metrics = result.metrics print('round {:2d}, metrics={}'.format(round_num, metrics))
round 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12357217), ('loss', 9.161968), ('num_examples', 4815), ('num_batches', 4815)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20563674), ('loss', 7.0862083), ('num_examples', 4790), ('num_batches', 4790)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30241227), ('loss', 5.6945825), ('num_examples', 4560), ('num_batches', 4560)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3867347), ('loss', 4.7210026), ('num_examples', 4900), ('num_batches', 4900)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42311886), ('loss', 4.205554), ('num_examples', 4585), ('num_batches', 4585)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.4501548), ('loss', 4.1297464), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.56590474), ('loss', 2.8927681), ('num_examples', 5250), ('num_batches', 5250)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.59917355), ('loss', 2.7431731), ('num_examples', 4840), ('num_batches', 4840)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())]) round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5717234), ('loss', 2.9738288), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])

结论

本教程到此结束。我们鼓励您探索我们开发的其他教程,以了解 TFF 框架的许多其他功能。