Copyright 2020 The TensorFlow Authors.
Effective TensorFlow 2
Overview
This guide provides a list of best practices for writing code using TensorFlow 2 (TF2). It is written for users who have recently switched over from TensorFlow 1 (TF1). Refer to the migrate section of the guide for more info on migrating your TF1 code to TF2.
Setup
Import TensorFlow and other dependencies for the examples in this guide.
Recommendations for idiomatic TensorFlow 2
Refactor your code into smaller modules
A good practice is to refactor your code into smaller functions that are called as needed. For best performance, you should try to decorate the largest blocks of computation that you can in a `tf.function` (note that the nested Python functions called by a `tf.function` do not require their own separate decorations, unless you want to use different `jit_compile` settings for the `tf.function`). Depending on your use case, this could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass.
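As a minimal sketch of this pattern (the helper and shapes below are illustrative assumptions, not part of the guide), a plain Python helper called from inside a `tf.function` is traced as part of the caller's graph and needs no decoration of its own:

```python
import tensorflow as tf

def dense_block(x, w, b):
    # A plain Python helper: it is traced as part of the caller's
    # graph, so it does not need its own @tf.function decoration.
    return tf.nn.relu(tf.matmul(x, w) + b)

@tf.function
def forward(x, w1, b1, w2, b2):
    # Decorate the largest block of computation once.
    h = dense_block(x, w1, b1)
    return tf.matmul(h, w2) + b2

x = tf.ones([4, 3])
w1, b1 = tf.ones([3, 5]), tf.zeros([5])
w2, b2 = tf.ones([5, 2]), tf.zeros([2])
out = forward(x, w1, b1, w2, b2)  # shape [4, 2]
```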
Adjust the default learning rate for some `tf.keras.optimizer`s
Some Keras optimizers have different learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.
There are no changes for `optimizers.SGD`, `optimizers.Adam`, or `optimizers.RMSprop`.
The following default learning rates have changed:
- `optimizers.Adagrad` from 0.01 to 0.001
- `optimizers.Adadelta` from 1.0 to 0.001
- `optimizers.Adamax` from 0.002 to 0.001
- `optimizers.Nadam` from 0.002 to 0.001
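If a model was tuned under the old defaults, you can pin the learning rate explicitly instead of relying on the new default. A sketch, using Adagrad's old TF1 default of 0.01 as the example:

```python
import tensorflow as tf

# Pass the learning rate explicitly rather than relying on the
# new TF2 default (0.001) if your model was tuned under TF1 (0.01).
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.01)
```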
Use `tf.Module`s and Keras layers to manage variables
`tf.Module`s and `tf.keras.layers.Layer`s offer the convenient `variables` and `trainable_variables` properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
Keras layers/models inherit from `tf.train.Checkpointable` and are integrated with `@tf.function`, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras' `Model.fit` API to take advantage of these integrations.
Read the section on transfer learning and fine-tuning in the Keras guide to learn how to collect a subset of relevant variables using Keras.
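The recursive variable collection described above can be sketched with `tf.Module` (the module names and shapes here are illustrative assumptions):

```python
import tensorflow as tf

class Dense(tf.Module):
    # A minimal module that creates and owns its own variables.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        return tf.nn.relu(x @ self.w + self.b)

class MLP(tf.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = Dense(4, 8)
        self.layer2 = Dense(8, 2)

    def __call__(self, x):
        return self.layer2(self.layer1(x))

mlp = MLP()
# trainable_variables recursively gathers the variables of both
# nested Dense modules: two kernels and two biases.
num_vars = len(mlp.trainable_variables)  # 4
```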
Combine `tf.data.Dataset`s and `tf.function`
The TensorFlow Datasets package (`tfds`) contains utilities for loading predefined datasets as `tf.data.Dataset` objects. For this example, you can load the MNIST dataset using `tfds`:
Then prepare the data for training:
- Re-scale each image.
- Shuffle the order of the examples.
- Collect batches of images and labels.
To keep the example short, trim the dataset to only return 5 batches:
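The preparation steps above can be sketched as the following pipeline. A synthetic MNIST-shaped dataset stands in here so the example is self-contained; with TFDS you would start from `tfds.load('mnist', as_supervised=True)` instead:

```python
import tensorflow as tf

# Synthetic stand-in for MNIST: 100 image/label pairs.
images = tf.random.uniform([100, 28, 28, 1], maxval=255, dtype=tf.float32)
labels = tf.random.uniform([100], maxval=10, dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def scale(image, label):
    # Re-scale pixel values into [0, 1].
    return image / 255.0, label

BUFFER_SIZE = 100
BATCH_SIZE = 32

train_data = (dataset
              .map(scale)            # re-scale each image
              .shuffle(BUFFER_SIZE)  # shuffle the order of the examples
              .batch(BATCH_SIZE)     # collect batches of images and labels
              .take(5))              # trim to at most 5 batches
```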
Use regular Python iteration to iterate over training data that fits in memory. Otherwise, `tf.data.Dataset` is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in `tf.function`, which replaces Python iteration with the equivalent graph operations using AutoGraph.
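A sketch of such wrapping, assuming a toy dataset of integers:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(2)

@tf.function
def sum_batches(ds):
    # AutoGraph replaces this Python for-loop with the equivalent
    # graph operations, so prefetching overlaps with computation.
    total = tf.constant(0)
    for batch in ds:
        total += tf.reduce_sum(batch)
    return total

total = sum_batches(dataset)  # 0 + 1 + ... + 9 = 45
```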
If you use the Keras `Model.fit` API, you won't have to worry about dataset iteration.
Use Keras training loops
If you don't need low-level control of your training process, using Keras' built-in `fit`, `evaluate`, and `predict` methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or sub-classed).
The advantages of these methods include:
- They accept Numpy arrays, Python generators, and `tf.data.Dataset`s.
- They apply regularization and activation losses automatically.
- They support `tf.distribute`, where the training code remains the same regardless of the hardware configuration.
- They support arbitrary callables as losses and metrics.
- They support callbacks like `tf.keras.callbacks.TensorBoard`, and custom callbacks.
- They are performant, automatically using TensorFlow graphs.
Here is an example of training a model using a `Dataset`. For details on how this works, check out the tutorials.
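A minimal sketch of such a training run; the tiny model and synthetic dataset below are illustrative assumptions:

```python
import tensorflow as tf

# A small synthetic binary-classification dataset.
x = tf.random.normal([64, 10])
y = tf.random.uniform([64], maxval=2, dtype=tf.int32)
train_data = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(2),
])
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

# Model.fit handles dataset iteration, metrics, and callbacks.
history = model.fit(train_data, epochs=2, verbose=0)
```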
Customize training and write your own loop
If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on customizing `fit` to learn more.

You can also implement many things as a `tf.keras.callbacks.Callback`.
This method has many of the advantages mentioned previously, but gives you control of the train step and even the outer loop.
There are three steps to a standard training loop:
1. Iterate over a Python generator or `tf.data.Dataset` to get batches of examples.
2. Use `tf.GradientTape` to collect gradients.
3. Use one of the `tf.keras.optimizers` to apply weight updates to the model's variables.
Remember:
- Always include a `training` argument on the `call` method of subclassed layers and models.
- Make sure to call the model with the `training` argument set correctly.
- Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
There is no need to run variable initializers or to add manual control dependencies. `tf.function` handles automatic control dependencies and variable initialization on creation for you.
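The three steps above can be sketched as follows; the model, loss, and synthetic data are illustrative assumptions:

```python
import tensorflow as tf

x = tf.random.normal([64, 10])
y = tf.random.uniform([64], maxval=2, dtype=tf.int32)
train_data = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)  # set `training` explicitly
        loss = loss_fn(labels, logits)
        # In a custom loop, regularization losses are handled manually.
        loss += tf.reduce_sum(model.losses)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for inputs, labels in train_data:    # step 1: iterate over batches
    loss = train_step(inputs, labels)  # steps 2-3 inside the tf.function
```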
Take advantage of `tf.function` with Python control flow
`tf.function` provides a way to convert data-dependent control flow into graph-mode equivalents like `tf.cond` and `tf.while_loop`.
One common place where data-dependent control flow appears is in sequence models. `tf.keras.layers.RNN` wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll as follows.
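A sketch of dynamic unrolling, using a simple tanh cell written as plain TensorFlow ops (the cell, weights, and shapes here are illustrative assumptions). The Python for-loop over `tf.range` becomes a `tf.while_loop` in the traced graph:

```python
import tensorflow as tf

@tf.function
def dynamic_rnn(inputs, w_x, w_h, b):
    # inputs: [batch, time, features] -> [time, batch, features]
    inputs = tf.transpose(inputs, [1, 0, 2])
    timesteps = tf.shape(inputs)[0]
    batch = tf.shape(inputs)[1]
    units = tf.shape(w_h)[0]
    state = tf.zeros([batch, units])
    outputs = tf.TensorArray(tf.float32, size=timesteps)
    # AutoGraph converts this loop into a tf.while_loop.
    for t in tf.range(timesteps):
        state = tf.tanh(tf.matmul(inputs[t], w_x) + tf.matmul(state, w_h) + b)
        outputs = outputs.write(t, state)
    # [time, batch, units] -> [batch, time, units]
    return tf.transpose(outputs.stack(), [1, 0, 2]), state

x = tf.random.normal([2, 5, 3])   # batch=2, time=5, features=3
w_x = tf.random.normal([3, 4])
w_h = tf.random.normal([4, 4])
b = tf.zeros([4])
seq, final_state = dynamic_rnn(x, w_x, w_h, b)  # seq: [2, 5, 4]
```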
Read the `tf.function` guide for more information.
New-style metrics and losses
Metrics and losses are both objects that work eagerly and in `tf.function`s.
A loss object is callable, and expects (`y_true`, `y_pred`) as arguments:
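A minimal sketch, using `MeanSquaredError` as the example; any built-in loss object is callable the same way:

```python
import tensorflow as tf

mse = tf.keras.losses.MeanSquaredError()
y_true = tf.constant([[0.0, 1.0], [0.0, 0.0]])
y_pred = tf.constant([[1.0, 1.0], [1.0, 0.0]])
# The loss object is called with (y_true, y_pred) and
# returns a scalar: the mean of the squared errors, 0.5.
loss = mse(y_true, y_pred)
```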
Use metrics to collect and display data
You can use `tf.metrics` to aggregate data and `tf.summary` to log summaries, redirecting them to a writer using a context manager. The summaries are emitted directly to the writer, which means that you must provide the `step` value at the call site.
Use `tf.metrics` to aggregate data before logging them as summaries. Metrics are stateful; they accumulate values and return a cumulative result when you call the `result` method (such as `Mean.result`). Clear accumulated values with `Model.reset_states`.
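A sketch of aggregating with a metric and logging with `tf.summary`; the log directory path is an arbitrary example, and `reset_state` (the newer spelling of `reset_states`, available since TF 2.5) clears the accumulated values:

```python
import tensorflow as tf

loss_metric = tf.keras.metrics.Mean(name='train_loss')
writer = tf.summary.create_file_writer('/tmp/summaries')  # example path

for step, loss in enumerate([0.3, 0.2, 0.1]):
    loss_metric.update_state(loss)  # accumulate values
    with writer.as_default():
        # The step value must be provided at the call site.
        tf.summary.scalar('loss', loss_metric.result(), step=step)

cumulative = float(loss_metric.result())  # mean of 0.3, 0.2, 0.1 -> 0.2
loss_metric.reset_state()                 # clear accumulated values
```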
Visualize the generated summaries by pointing TensorBoard to the summary log directory:
Use the `tf.summary` API to write summary data for visualization in TensorBoard. For more info, read the `tf.summary` guide.
Keras metric names
Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that exact string is used as the metric's `name`. These names are visible in the history object returned by `model.fit`, and in the logs passed to `keras.callbacks`.
Debugging
Use eager execution to run your code step-by-step to inspect shapes, data types, and values. Certain APIs, like `tf.function`, `tf.keras`, etc., are designed to use graph execution for performance and portability. When debugging, use `tf.config.run_functions_eagerly(True)` to use eager execution inside this code.
For example:
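A sketch of the pattern; the function below is an illustrative assumption, and with eager execution enabled you could set an ordinary Python breakpoint inside it:

```python
import tensorflow as tf

@tf.function
def f(x):
    if x > 0:
        # With eager execution on, you could insert
        # `import pdb; pdb.set_trace()` here to step through.
        x = x + 1
    return x

tf.config.run_functions_eagerly(True)   # run the tf.function eagerly
result = f(tf.constant(1))              # Python-level debugging now works
tf.config.run_functions_eagerly(False)  # restore graph execution
```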
This also works inside Keras models and other APIs that support eager execution:
Notes:
- `tf.keras.Model` methods such as `fit`, `evaluate`, and `predict` execute as graphs with `tf.function` under the hood.
- When using `tf.keras.Model.compile`, set `run_eagerly = True` to disable the `Model` logic from being wrapped in a `tf.function`.
- Use `tf.data.experimental.enable_debug_mode` to enable the debug mode for `tf.data`. Read the API docs for more details.
Do not keep `tf.Tensors` in your objects
These tensor objects might get created either in a `tf.function` or in the eager context, and these tensors behave differently. Always use `tf.Tensor`s only for intermediate values.
To track state, use `tf.Variable`s, as they are always usable from both contexts. Read the `tf.Variable` guide to learn more.
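A minimal sketch of state kept in a `tf.Variable` (the `Counter` module here is an illustrative assumption); the same object works identically eagerly and inside a `tf.function`:

```python
import tensorflow as tf

class Counter(tf.Module):
    def __init__(self):
        # A tf.Variable tracks state across both eager
        # and graph (tf.function) contexts.
        self.count = tf.Variable(0, dtype=tf.int64)

    @tf.function
    def increment(self):
        self.count.assign_add(1)
        return self.count

c = Counter()
c.increment()
c.increment()  # c.count is now 2
```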
Resources and further reading
Read the TF2 guides and tutorials to learn more about how to use TF2.
If you previously used TF1.x, it is highly recommended that you migrate your code to TF2. Read the migration guides to learn more.