Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/zh-cn/federated/tutorials/jax_support.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

TFF 中对 JAX 的实验性支持

除了作为 TensorFlow 生态系统的一部分,TFF 还致力于实现与其他前端和后端 ML 框架的互操作性。目前,对其他 ML 框架的支持仍处于孵化阶段,所支持的 API 和功能可能会发生变化(很大程度上取决于 TFF 用户的需求)。本教程描述如何使用 TFF 和 JAX 作为替代的 ML 前端,以及如何使用 XLA 编译器作为替代的后端。这里展示的示例基于一个完全原生的端到端 JAX/XLA 堆栈。跨框架混合代码(例如,JAX 和 TensorFlow)的可能性将在未来的教程中进行讨论。

我们一如既往地欢迎您的贡献。如果对 JAX/XLA 的支持或与其他 ML 框架的互操作能力对您来说很重要,请考虑帮助我们发展这些功能,使其与 TFF 的其他功能持平。

开始之前

关于如何配置您的环境,请参阅 TFF 文档的正文。根据运行本教程的位置,您可能希望取消注释并运行下面的部分或全部代码。

# !pip install --quiet --upgrade tensorflow-federated # !pip install --quiet --upgrade nest-asyncio # import nest_asyncio # nest_asyncio.apply()

本教程还假设您已经阅读了 TFF 的主要 TensorFlow 教程,并且熟悉 TFF 核心概念。如果您还没有阅读这些内容,请考虑至少审阅其中一个。

JAX 计算

在 TFF 中对 JAX 的支持被设计成与 TFF 与 TensorFlow 互操作的方式对称,从导入开始:

import jax import numpy as np import tensorflow_federated as tff

此外,与 TensorFlow 一样,表达任何 TFF 代码的基础是本地运行的逻辑。如下所示,您可以使用 @tff.jax_computation 封装容器在 JAX 中表达该逻辑。它的行为类似于您现在所熟悉的 @tff.tf_computation 。我们先从简单的内容开始,例如,将两个整数相加的计算:

@tff.jax_computation(np.int32, np.int32) def add_numbers(x, y): return jax.numpy.add(x, y)

您可以像通常使用 TFF计 算一样使用上面定义的 JAX 计算。例如,您可以查看其类型签名,如下所示:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

请注意,我们使用了 np.int32 来定义参数类型。TFF 不区分 Numpy 类型(如 np.int32)和 TensorFlow 类型(如 tf.int32)。从 TFF 的角度来看,它们只是指代同一事物的不同方式。

接下来,请记住 TFF 不是 Python(如果您不熟悉此内容,请查看我们之前的部分教程,如,有关自定义算法的内容)。您可以将 @tff.jax_computation 封装容器与任意可以跟踪和序列化的 JAX 代码一起使用,即您通常会使用 @jax.jit 进行注解并应被编译成 XLA 的代码(但您无需真的使用 @jax.jit 注解将 JAX 代码嵌入到 TFF 中)。

实际上,TFF 会在底层立即将 JAX 计算编译成 XLA。您可以通过手动从 add_numbers 中提取和打印序列化的 XLA 代码来亲自查看,如下所示:

comp_pb = tff.framework.serialize_computation(add_numbers) comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value) print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7 ENTRY xla_computation_add_numbers.7 { constant.4 = pred[] constant(false) parameter.1 = (s32[], s32[]) parameter(0) get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0 get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1 add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3) ROOT tuple.6 = (s32[]) tuple(add.5) }

对于在 TensorFlow 表达的计算,可以将表示为 XLA 代码的 JAX 计算看作 tf.GraphDef 的功能对等项。它可移植并可在各种支持 XLA 的环境中执行,就像 tf.GraphDef 可以在任何 TensorFlow 运行时上执行一样。

TFF 提供了一个基于 XLA 编译器的运行时堆栈作为后端。可以通过以下方式激活:

tff.backends.xla.set_local_python_execution_context()

现在,您可以执行我们上面定义的计算:

add_numbers(2, 3)
5

很简单。我们来继续做一些更复杂的事情,比如 MNIST。

使用预设 API 的 MNIST 训练示例

像往常一样,我们首先为数据批次和模型定义一组 TFF 类型(请记住,TFF 是一个强类型框架)。

import collections BATCH_TYPE = collections.OrderedDict([ ('pixels', tff.TensorType(np.float32, (50, 784))), ('labels', tff.TensorType(np.int32, (50,))) ]) MODEL_TYPE = collections.OrderedDict([ ('weights', tff.TensorType(np.float32, (784, 10))), ('bias', tff.TensorType(np.float32, (10,))) ])

现在,我们以模型和单批数据为参数,在 JAX 中为模型定义一个损失函数:

def loss(model, batch): y = jax.nn.softmax( jax.numpy.add( jax.numpy.matmul(batch['pixels'], model['weights']), model['bias'])) targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10) return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

现在,一种方法是使用预设 API。
下面是一个示例,演示如何使用我们的 API 根据刚才定义的损失函数创建训练流程。

STEP_SIZE = 0.001 trainer = tff.learning.build_jax_federated_averaging_process( BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

您可以像使用 TensorFlow 中的 tf.Keras 模型的 trainer 构建一样,使用上面的代码。例如,以下是如何为训练创建初始模型的方法:

initial_model = trainer.initialize() initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

为了进行实际训练,我们需要一些数据。简单起见,我们随机生成数据。由于数据是随机的,我们将对训练数据进行评估,否则,对于随机评估数据,很难期望模型正常执行。此外,对于这个小规模演示,我们无需担心随机采样的客户端(我们会将其作为练习,让用户按照其他教程中的模板来探索这些类型的变化):

def random_batch(): pixels = np.random.uniform( low=0.0, high=1.0, size=(50, 784)).astype(np.float32) labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32) return collections.OrderedDict([('pixels', pixels), ('labels', labels)]) NUM_CLIENTS = 2 NUM_BATCHES = 10 train_data = [ [random_batch() for _ in range(NUM_BATCHES)] for _ in range(NUM_CLIENTS)]

有了这些准备工作,我们可以执行单个步骤的训练,如下所示:

trained_model = trainer.next(initial_model, train_data) trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05, 2.54597180e-05, ..., 5.61640409e-05, -5.32875274e-05, -4.62881755e-04], [ 7.30908650e-05, 4.67643113e-05, 2.03352147e-06, ..., 3.77510623e-05, 3.52839161e-05, -4.59865667e-04], [ 8.14835730e-05, 3.03147244e-05, -1.89143739e-05, ..., 1.12527239e-04, 4.09212225e-06, -4.59960109e-04], ..., [ 9.23552434e-05, 2.44302555e-06, -2.20817346e-05, ..., 7.61375341e-05, 1.76906979e-05, -4.43495519e-04], [ 1.17451040e-04, 2.47748958e-05, 1.04728279e-05, ..., 5.26388249e-07, 7.21131510e-05, -4.67137404e-04], [ 3.75041491e-05, 6.58061981e-05, 1.14522081e-05, ..., 2.52584141e-05, 3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04, 2.6502126e-05, -1.9462314e-05, 8.1269856e-05, 2.1832302e-04, 1.6636557e-04, 1.2815947e-04, 9.0642272e-05, 7.7109929e-05, -9.1987278e-04], dtype=float32))])

我们来评估一下训练步骤的结果。简单起见,我们可以对其进行集中评估:

import itertools eval_data = list(itertools.chain.from_iterable(train_data)) def average_loss(model, data): return np.mean([loss(model, batch) for batch in data]) print (average_loss(initial_model, eval_data)) print (average_loss(trained_model, eval_data))
2.3025854 2.282762

损失正在减少。太棒了!现在,我们来多运行几轮:

NUM_ROUNDS = 20 for _ in range(NUM_ROUNDS): trained_model = trainer.next(trained_model, train_data) print(average_loss(trained_model, eval_data))
2.2685437 2.257856 2.2495182 2.2428129 2.2372835 2.2326245 2.2286277 2.2251441 2.2220676 2.219318 2.2168345 2.2145717 2.2124937 2.2105706 2.2087805 2.2071042 2.2055268 2.2040353 2.2026198 2.2012706

如您所见,尽管实验性 API 在功能上还不能与 TensorFlow API 相提并论,将 JAX 与 TFF 配合使用并没有太大的不同。

在底层

如果您不喜欢使用我们的预设 API,您可以实现您自己的自定义计算。所用方法与您在 TensorFlow 的自定义算法教程中看到的大致相同,只是您将使用 JAX 的梯度下降机制。例如,下面是定义能够在单个 mini-batch 上更新模型的 JAX 计算的方法:

@tff.jax_computation(MODEL_TYPE, BATCH_TYPE) def train_on_one_batch(model, batch): grads = jax.grad(loss)(model, batch) return collections.OrderedDict([ (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias'] ])

下面是测试其是否能够正常工作的方法:

sample_batch = random_batch() trained_model = train_on_one_batch(initial_model, sample_batch) print(average_loss(initial_model, [sample_batch])) print(average_loss(trained_model, [sample_batch]))
2.3025854 2.2977567

使用 JAX 需要注意的一点是,它不提供与 tf.data.Dataset 相同的功能。因此,为了迭代数据集,您需要使用 TFF 的声明性结构来对序列进行操作,如下所示:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE)) def train_on_one_client(model, batches): return tff.sequence_reduce(batches, model, train_on_one_batch)

我们来了解一下它的工作方式:

sample_dataset = [random_batch() for _ in range(100)] trained_model = train_on_one_client(initial_model, sample_dataset) print(average_loss(initial_model, sample_dataset)) print(average_loss(trained_model, sample_dataset))
2.3025854 2.2284968

执行单轮训练的计算和您在 TensorFlow 教程中看到的一样:

@tff.federated_computation( tff.FederatedType(MODEL_TYPE, tff.SERVER), tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS)) def train_one_round(model, federated_data): locally_trained_models = tff.federated_map( train_on_one_client, collections.OrderedDict([ ('model', tff.federated_broadcast(model)), ('batches', federated_data)])) return tff.federated_mean(locally_trained_models)

我们来了解一下它的工作方式:

trained_model = train_one_round(initial_model, train_data) print(average_loss(initial_model, eval_data)) print(average_loss(trained_model, eval_data))
2.3025854 2.282762

如您所见,在 TFF 中使用 JAX,无论是通过预设 API,还是直接使用低级 TFF 构造,都与将 TFF 与 TensorFlow 一起使用类似。请继续关注未来的更新,如果您希望看到对跨 ML 框架的互操作性的更好支持,请随时向我们发送拉取请求!