Copyright 2019 The TensorFlow Authors.
Use XLA with tf.function
This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.
First, load TensorFlow and enable eager execution.
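A minimal setup sketch; note that eager execution is already the default in TF 2.x, so the explicit call is only relevant under TF 1.x compatibility behavior:

```python
import tensorflow as tf

# Eager execution is the default in TF 2.x; this explicit call is only
# needed when running with TF 1.x compatibility behavior.
tf.compat.v1.enable_eager_execution()
```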
Then define some necessary constants and prepare the MNIST dataset.
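One possible setup is sketched below; the constant names and values (`IMAGE_SIZE`, `NUM_CLASSES`, `TRAIN_BATCH_SIZE`, `TRAIN_STEPS`) and the `cast` helper are illustrative choices:

```python
# Size of each input image: 28 x 28 pixels, flattened to a vector.
IMAGE_SIZE = 28 * 28
# Number of distinct digit classes, 0 through 9.
NUM_CLASSES = 10
# Number of examples in each training batch (step).
TRAIN_BATCH_SIZE = 100
# Number of training steps to run.
TRAIN_STEPS = 1000

# Load the MNIST dataset and build a batched, repeating training pipeline.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()

def cast(images, labels):
  # Flatten the images and cast inputs to the dtypes expected by the model.
  images = tf.cast(tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)
```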
Finally, define the model and the optimizer. The model uses a single dense layer.
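A sketch of a single-dense-layer model; the choice of Adam as the optimizer here is an illustrative assumption:

```python
# A single dense layer maps the flattened image to per-class logits.
layer = tf.keras.layers.Dense(NUM_CLASSES)
# Optimizer for the training step (Adam chosen for illustration).
optimizer = tf.keras.optimizers.Adam()
```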
Define the training function
In the training function, you compute the predicted labels using the layer defined above, and then minimize the loss by applying its gradient with the optimizer. In order to compile the computation using XLA, place it inside a `tf.function` with `jit_compile=True`, as in the sketch below.
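A sketch of such a training step, assuming the `cast` helper, `layer`, and `optimizer` defined above:

```python
@tf.function(jit_compile=True)
def train_mnist(images, labels):
  images, labels = cast(images, labels)

  with tf.GradientTape() as tape:
    # Forward pass: per-class logits and mean cross-entropy loss.
    predicted_labels = layer(images)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=predicted_labels, labels=labels))
  # Compute gradients of the loss and apply them with the optimizer.
  layer_variables = layer.trainable_variables
  grads = tape.gradient(loss, layer_variables)
  optimizer.apply_gradients(zip(grads, layer_variables))
```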
Train and test the model
Once you have defined the training function, train the model:
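One way to drive the training loop over the dataset, using the illustrative `TRAIN_STEPS` constant from above:

```python
# Run the compiled training step until the step budget is exhausted.
for images, labels in train_ds:
  if optimizer.iterations > TRAIN_STEPS:
    break
  train_mnist(images, labels)
```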
And, finally, check the accuracy:
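A sketch of the accuracy check on the held-out test split loaded earlier:

```python
# Evaluate on the test set: fraction of images whose argmax logit matches the label.
images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
```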
Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which has enabled fusion optimizations. Using the introspection facilities, we can see the HLO code (other interesting possible values for `stage` are `optimized_hlo` for HLO after optimizations and `optimized_hlo_dot` for a Graphviz graph):
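For example, a sketch using `experimental_get_compiler_ir` on the compiled function; the `images` and `labels` arguments are sample inputs of the expected shapes and dtypes:

```python
# Dump the unoptimized HLO generated for this function and these input shapes.
print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))
```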