Path: blob/master/site/zh-cn/federated/tutorials/jax_support.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
TFF 中对 JAX 的实验性支持
除了作为 TensorFlow 生态系统的一部分,TFF 还致力于实现与其他前端和后端 ML 框架的互操作性。目前,对其他 ML 框架的支持仍处于孵化阶段,所支持的 API 和功能可能会发生变化(很大程度上取决于 TFF 用户的需求)。本教程描述如何使用 TFF 和 JAX 作为替代的 ML 前端,以及如何使用 XLA 编译器作为替代的后端。这里展示的示例基于一个完全原生的端到端 JAX/XLA 堆栈。跨框架混合代码(例如,JAX 和 TensorFlow)的可能性将在未来的教程中进行讨论。
我们一如既往地欢迎您的贡献。如果对 JAX/XLA 的支持或与其他 ML 框架的互操作能力对您来说很重要,请考虑帮助我们发展这些功能,使其与 TFF 的其他功能持平。
开始之前
关于如何配置您的环境,请参阅 TFF 文档的正文。根据运行本教程的位置,您可能希望取消注释并运行下面的部分或全部代码。
本教程还假设您已经阅读了 TFF 的主要 TensorFlow 教程,并且熟悉 TFF 核心概念。如果您还没有阅读这些内容,请考虑至少审阅其中一个。
JAX 计算
在 TFF 中对 JAX 的支持被设计成与 TFF 与 TensorFlow 互操作的方式对称,从导入开始:
此外,与 TensorFlow 一样,表达任何 TFF 代码的基础是本地运行的逻辑。如下所示,您可以使用 @tff.jax_computation
封装容器在 JAX 中表达该逻辑。它的行为类似于您现在所熟悉的 @tff.tf_computation
。我们先从简单的内容开始,例如,将两个整数相加的计算:
您可以像通常使用 TFF计 算一样使用上面定义的 JAX 计算。例如,您可以查看其类型签名,如下所示:
请注意,我们使用了 np.int32
来定义参数类型。TFF 不区分 Numpy 类型(如 np.int32
)和 TensorFlow 类型(如 tf.int32
)。从 TFF 的角度来看,它们只是指代同一事物的不同方式。
接下来,请记住 TFF 不是 Python(如果您不熟悉此内容,请查看我们之前的部分教程,如,有关自定义算法的内容)。您可以将 @tff.jax_computation
封装容器与任意可以跟踪和序列化的 JAX 代码一起使用,即您通常会使用 @jax.jit
进行注解并应被编译成 XLA 的代码(但您无需真的使用 @jax.jit
注解将 JAX 代码嵌入到 TFF 中)。
实际上,TFF 会在底层立即将 JAX 计算编译成 XLA。您可以通过手动从 add_numbers
中提取和打印序列化的 XLA 代码来亲自查看,如下所示:
对于在 TensorFlow 表达的计算,可以将表示为 XLA 代码的 JAX 计算看作 tf.GraphDef
的功能对等项。它可移植并可在各种支持 XLA 的环境中执行,就像 tf.GraphDef
可以在任何 TensorFlow 运行时上执行一样。
TFF 提供了一个基于 XLA 编译器的运行时堆栈作为后端。可以通过以下方式激活:
现在,您可以执行我们上面定义的计算:
很简单。我们来继续做一些更复杂的事情,比如 MNIST。
使用预设 API 的 MNIST 训练示例
像往常一样,我们首先为数据批次和模型定义一组 TFF 类型(请记住,TFF 是一个强类型框架)。
现在,我们以模型和单批数据为参数,在 JAX 中为模型定义一个损失函数:
现在,一种方法是使用预设 API。
下面是一个示例,演示如何使用我们的 API 根据刚才定义的损失函数创建训练流程。
您可以像使用 TensorFlow 中的 tf.Keras
模型的 trainer 构建一样,使用上面的代码。例如,以下是如何为训练创建初始模型的方法:
为了进行实际训练,我们需要一些数据。简单起见,我们随机生成数据。由于数据是随机的,我们将对训练数据进行评估,否则,对于随机评估数据,很难期望模型正常执行。此外,对于这个小规模演示,我们无需担心随机采样的客户端(我们会将其作为练习,让用户按照其他教程中的模板来探索这些类型的变化):
有了这些准备工作,我们可以执行单个步骤的训练,如下所示:
我们来评估一下训练步骤的结果。简单起见,我们可以对其进行集中评估:
损失正在减少。太棒了!现在,我们来多运行几轮:
如您所见,尽管实验性 API 在功能上还不能与 TensorFlow API 相提并论,将 JAX 与 TFF 配合使用并没有太大的不同。
在底层
如果您不喜欢使用我们的预设 API,您可以实现您自己的自定义计算。所用方法与您在 TensorFlow 的自定义算法教程中看到的大致相同,只是您将使用 JAX 的梯度下降机制。例如,下面是定义能够在单个 mini-batch 上更新模型的 JAX 计算的方法:
下面是测试其是否能够正常工作的方法:
使用 JAX 需要注意的一点是,它不提供与 tf.data.Dataset
相同的功能。因此,为了迭代数据集,您需要使用 TFF 的声明性结构来对序列进行操作,如下所示:
我们来了解一下它的工作方式:
执行单轮训练的计算和您在 TensorFlow 教程中看到的一样:
我们来了解一下它的工作方式:
如您所见,在 TFF 中使用 JAX,无论是通过预设 API,还是直接使用低级 TFF 构造,都与将 TFF 与 TensorFlow 一起使用类似。请继续关注未来的更新,如果您希望看到对跨 ML 框架的互操作性的更好支持,请随时向我们发送拉取请求!