Copyright 2020 The TensorFlow Authors.
Basic training loops
In the previous guides, you have learned about tensors, variables, gradient tape, and modules. In this guide, you will fit these all together to train models.
TensorFlow also includes the tf.keras API, a high-level neural network API that provides useful abstractions to reduce boilerplate. However, in this guide, you will use basic classes.
Setup
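The cells below assume a standard TensorFlow 2.x environment with Matplotlib available for plotting; a minimal setup might look like this:

```python
import tensorflow as tf

import matplotlib.pyplot as plt
```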
Solving machine learning problems
Solving a machine learning problem usually consists of the following steps:
Obtain training data.
Define the model.
Define a loss function.
Run through the training data, calculating the loss from the ideal value.
Calculate gradients for that loss and use an optimizer to adjust the variables to fit the data.
Evaluate your results.
For illustration purposes, in this guide you'll develop a simple linear model, $f(x) = x * W + b$, which has two variables: $W$ (weights) and $b$ (bias).

This is the most basic of machine learning problems: Given $x$ and $y$, try to find the slope and offset of a line via simple linear regression.
Data
Supervised learning uses inputs (usually denoted as x) and outputs (denoted y, often called labels). The goal is to learn from paired inputs and outputs so that you can predict the value of an output from an input.
Each input of your data, in TensorFlow, is almost always represented by a tensor, and is often a vector. In supervised training, the output (or value you'd like to predict) is also a tensor.
Here is some data synthesized by adding Gaussian (Normal) noise to points along a line.
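As a sketch, the data might be generated like this; the true slope of 3.0, offset of 2.0, and 1000 examples are illustrative choices (the Keras example later in this guide assumes 1000 points):

```python
# The actual line the data is generated from (illustrative values).
TRUE_W = 3.0
TRUE_B = 2.0

NUM_EXAMPLES = 1000

# A vector of random x values.
x = tf.random.normal(shape=[NUM_EXAMPLES])

# Gaussian noise to perturb the outputs.
noise = tf.random.normal(shape=[NUM_EXAMPLES])

# y is a noisy version of a point on the line.
y = x * TRUE_W + TRUE_B + noise
```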
Tensors are usually gathered together in batches, or groups of inputs and outputs stacked together. Batching can confer some training benefits and works well with accelerators and vectorized computation. Given how small this dataset is, you can treat the entire dataset as a single batch.
Define the model
Use tf.Variable to represent all weights in a model. A tf.Variable stores a value and provides this in tensor form as needed. See the variable guide for more details.

Use tf.Module to encapsulate the variables and the computation. You could use any Python object, but this way it can be easily saved.
Here, you define both w and b as variables.
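A minimal sketch, with fixed initial values of 5.0 and 0.0 (the MyModel name is illustrative):

```python
class MyModel(tf.Module):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Fixed initial values, for illustration only.
    self.w = tf.Variable(5.0)
    self.b = tf.Variable(0.0)

  def __call__(self, x):
    return self.w * x + self.b

model = MyModel()

# tf.Module automatically tracks the variables assigned to it.
print("Variables:", model.variables)

# Sanity check: 3.0 * 5.0 + 0.0 == 15.0
assert model(3.0).numpy() == 15.0
```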
The initial variables are set here in a fixed way, but Keras comes with a number of initializers you could use, with or without the rest of Keras.
Define a loss function
A loss function measures how well the output of a model for a given input matches the target output. The goal is to minimize this difference during training. Define the standard L2 loss, also known as the mean squared error:
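A minimal implementation of that loss:

```python
# This computes a single loss value for an entire batch.
def loss(target_y, predicted_y):
  return tf.reduce_mean(tf.square(target_y - predicted_y))
```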
Before training the model, you can visualize the loss value by plotting the model's predictions in red and the training data in blue:
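One way to do this with Matplotlib, assuming the x, y, model, and loss names defined above:

```python
plt.scatter(x, y, c="b", label="data")
plt.scatter(x, model(x), c="r", label="predictions")
plt.legend()
plt.show()

print("Current loss: %1.6f" % loss(y, model(x)).numpy())
```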
Define a training loop
The training loop consists of repeatedly doing three tasks in order:
Sending a batch of inputs through the model to generate outputs
Calculating the loss by comparing the outputs to the expected output (or label)
Using gradient tape to find the gradients
Optimizing the variables with those gradients
For this example, you can train the model using gradient descent.
There are many variants of the gradient descent scheme that are captured in tf.keras.optimizers. But in the spirit of building from first principles, here you will implement the basic math yourself with the help of tf.GradientTape for automatic differentiation and tf.Variable.assign_sub for decrementing a value (which combines assignment and subtraction):
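A sketch of a single training step under those assumptions (the train name and the learning_rate parameter are illustrative):

```python
# Given a callable model, inputs, outputs, and a learning rate...
def train(model, x, y, learning_rate):
  with tf.GradientTape() as t:
    # Trainable variables are automatically tracked by GradientTape.
    current_loss = loss(y, model(x))

  # Use GradientTape to calculate the gradients with respect to w and b.
  dw, db = t.gradient(current_loss, [model.w, model.b])

  # Subtract the gradient scaled by the learning rate.
  model.w.assign_sub(learning_rate * dw)
  model.b.assign_sub(learning_rate * db)
```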
For a look at training, you can send the same batch of x and y through the training loop, and see how W and b evolve.
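One possible loop over a fixed number of epochs, building on the train function above; the report helper and the choice of 10 epochs are illustrative:

```python
model = MyModel()

# Collect the history of W and b values to plot later.
weights = []
biases = []
epochs = range(10)

def report(model, loss_value):
  return (f"W = {model.w.numpy():1.2f}, b = {model.b.numpy():1.2f}, "
          f"loss = {float(loss_value):2.5f}")

def training_loop(model, x, y):
  for epoch in epochs:
    # Update the model with the single giant batch.
    train(model, x, y, learning_rate=0.1)

    # Track the values for plotting later.
    weights.append(model.w.numpy())
    biases.append(model.b.numpy())
    current_loss = loss(y, model(x))

    print(f"Epoch {epoch:2d}: {report(model, current_loss)}")
```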
Do the training
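For example:

```python
current_loss = loss(y, model(x))
print(f"Starting: {report(model, current_loss)}")

training_loop(model, x, y)
```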
Plot the evolution of the weights over time:
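For example, assuming the weights, biases, TRUE_W, and TRUE_B names from the earlier cells:

```python
plt.plot(epochs, weights, label="Weights", color="blue")
plt.plot(epochs, [TRUE_W] * len(epochs), "--", label="True weight", color="blue")

plt.plot(epochs, biases, label="bias", color="red")
plt.plot(epochs, [TRUE_B] * len(epochs), "--", label="True bias", color="red")

plt.legend()
plt.show()
```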
Visualize how the trained model performs
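Reusing the same scatter plot as before shows the predictions now tracking the data:

```python
plt.scatter(x, y, c="b", label="data")
plt.scatter(x, model(x), c="r", label="predictions")
plt.legend()
plt.show()

print("Current loss: %1.6f" % loss(y, model(x)).numpy())
```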
The same solution, but with Keras
It's useful to contrast the code above with the equivalent in Keras.
Defining the model looks exactly the same if you subclass tf.keras.Model. Remember that Keras models ultimately inherit from tf.Module.
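A sketch of the Keras version (the MyModelKeras name is illustrative); note that a Keras subclass implements call rather than __call__:

```python
class MyModelKeras(tf.keras.Model):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Same fixed initial values as before, for illustration.
    self.w = tf.Variable(5.0)
    self.b = tf.Variable(0.0)

  def call(self, x):
    # Keras models implement call() rather than __call__().
    return self.w * x + self.b

keras_model = MyModelKeras()

# The hand-written training loop from above works unchanged,
# because a Keras model is also a tf.Module.
training_loop(keras_model, x, y)
```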
Rather than write new training loops each time you create a model, you can use the built-in features of Keras as a shortcut. This can be useful when you do not want to write or debug Python training loops.
If you do, you will need to use model.compile() to set the parameters, and model.fit() to train. It can be less code to use Keras implementations of L2 loss and gradient descent, again as a shortcut. Keras losses and optimizers can be used outside of these convenience functions, too, and the previous example could have used them.
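A sketch of that configuration, using the built-in SGD optimizer and mean-squared-error loss to mirror the hand-written loop:

```python
keras_model = MyModelKeras()

# compile() sets the training parameters.
keras_model.compile(
    # By default, fit() uses tf.function() to wrap the training step.
    # You can turn that off for debugging; here it stays on.
    run_eagerly=False,

    # Keras gradient descent, equivalent to the hand-written update above.
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),

    # The built-in mean-squared-error loss; the loss function
    # defined earlier would work here too.
    loss=tf.keras.losses.mean_squared_error,
)
```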
Keras fit expects batched data or a complete dataset as a NumPy array. NumPy arrays are chopped into batches, with a default batch size of 32.
In this case, to match the behavior of the hand-written loop, you should pass x in as a single batch of size 1000.
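For example:

```python
print(x.shape[0])
keras_model.fit(x, y, epochs=10, batch_size=1000)
```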
Note that Keras prints out the loss after training, not before, so the first loss appears lower, but otherwise this shows essentially the same training performance.
Next steps
In this guide, you have seen how to use the core classes of tensors, variables, modules, and gradient tape to build and train a model, and further how those ideas map to Keras.
This is, however, an extremely simple problem. For a more practical introduction, see Custom training walkthrough.
For more on using built-in Keras training loops, see this guide. For more on training loops and Keras, see this guide. For writing custom distributed training loops, see this guide.