Path: blob/master/site/zh-cn/xla/tutorials/compile.ipynb
25118 views
Kernel: Python 3
Copyright 2019 The TensorFlow Authors.
In [ ]:
将 XLA 与 tf.function 结合使用
本教程将训练一个 TensorFlow 模型来对 MNIST 数据集进行分类,我们会使用 XLA 编译训练函数。
首先,加载 TensorFlow 并启用 Eager Execution。
In [ ]:
In [ ]:
随后,定义一些必要的常量并准备 MNIST 数据集。
In [ ]:
最后,定义模型和优化器。该模型使用单个密集层。
In [ ]:
定义训练函数
在训练函数中,您可以使用上面定义的层来获取预测的标签,然后使用优化器来尽可能减小损失的梯度。为了使用 XLA 编译计算,请将其置于 jit_compile=True
的 tf.function
中。
In [ ]:
训练并测试模型
定义训练函数后,请定义模型。
In [ ]:
最后,检查准确率:
In [ ]:
在后台,XLA 编译器将整个 TF 函数编译为 HLO,后者已启用融合优化。使用自省工具,我们可以查看 HLO 代码(“stage”的其他有趣的可能值是优化后的 HLO 的 optimized_hlo
和 Graphviz 计算图的 optimized_hlo_dot
):
In [ ]: