Copyright 2020 The TensorFlow Authors.
Effective TensorFlow 2
Overview
This guide provides a list of best practices for writing code using TensorFlow 2 (TF2). It is written for users who have recently switched over from TensorFlow 1 (TF1). Refer to the migrate section of the guide for more info on migrating your TF1 code to TF2.
Setup
Import TensorFlow and other dependencies for the examples in this guide.
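A minimal setup cell might look like this. Each later sketch in this guide repeats the imports it needs, so it can run on its own. The TensorFlow Datasets line is only mentioned in a comment because that package is installed separately.

```python
import tensorflow as tf

# The MNIST examples in this guide can also use the TensorFlow Datasets package:
#   import tensorflow_datasets as tfds
# which is installed separately (pip install tensorflow-datasets).
print(tf.__version__)
```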
Recommendations for idiomatic TensorFlow 2
Refactor your code into smaller modules
A good practice is to refactor your code into smaller functions that are called as needed. For best performance, decorate the largest blocks of computation you can with tf.function. Nested Python functions called by a tf.function do not need their own decorations, unless you want to give them different jit_compile settings. Depending on your use case, the decorated block could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass.
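The following sketch shows the idea on a toy linear-regression problem. The variables w and b, the helpers forward and loss_fn, and the plain gradient-descent update are all hypothetical and exist only for illustration. The small helpers are undecorated. Only train_steps is a tf.function, and it wraps a whole multi-step loop rather than a single step.

```python
import tensorflow as tf

# Hypothetical model parameters for illustration only.
w = tf.Variable(tf.random.normal([3, 1]), name="w")
b = tf.Variable(tf.zeros([1]), name="b")
learning_rate = 0.1

def forward(x):
  # Plain Python helper: no decorator needed, it is traced as part of
  # whichever tf.function calls it.
  return tf.matmul(x, w) + b

def loss_fn(x, y):
  return tf.reduce_mean(tf.square(forward(x) - y))

@tf.function
def train_steps(x, y, num_steps):
  # The whole loop (not just one step) is compiled into a single graph;
  # AutoGraph turns the tf.range loop into a tf.while_loop.
  for _ in tf.range(num_steps):
    with tf.GradientTape() as tape:
      loss = loss_fn(x, y)
    grad_w, grad_b = tape.gradient(loss, [w, b])
    w.assign_sub(learning_rate * grad_w)
    b.assign_sub(learning_rate * grad_b)
  return loss_fn(x, y)

x = tf.random.normal([64, 3])
y = tf.matmul(x, tf.constant([[1.0], [-2.0], [0.5]])) + 0.3
initial_loss = float(loss_fn(x, y))
final_loss = float(train_steps(x, y, tf.constant(100)))
print(initial_loss, final_loss)
```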
Adjust the default learning rate for some tf.keras.optimizers
Some Keras optimizers have different learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.
There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.
The following default learning rates have changed:
optimizers.Adagrad from 0.01 to 0.001
optimizers.Adadelta from 1.0 to 0.001
optimizers.Adamax from 0.002 to 0.001
optimizers.Nadam from 0.002 to 0.001
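One way to check the defaults in your installed version is to read them from each optimizer's config. This sketch assumes get_config() exposes a "learning_rate" entry, which holds for the built-in Keras optimizers. To keep an old TF1 default, pass learning_rate explicitly.

```python
import tensorflow as tf

# Read the default learning rate of each optimizer whose default changed in TF2.
defaults = {
    name: float(getattr(tf.keras.optimizers, name)().get_config()["learning_rate"])
    for name in ["Adagrad", "Adadelta", "Adamax", "Nadam"]
}
print(defaults)

# To reproduce a TF1 default (for example Adagrad's 0.01), set it explicitly.
adagrad_tf1_default = tf.keras.optimizers.Adagrad(learning_rate=0.01)
```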
Use tf.Modules and Keras layers to manage variables
tf.Modules and tf.keras.layers.Layers offer the convenient variables and trainable_variables properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
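The sketch below shows this recursive collection with two hypothetical modules, Dense and MLP, written for illustration. Sub-modules and tf.Variables assigned as attributes are tracked automatically. The variables property also includes non-trainable variables, while trainable_variables leaves them out.

```python
import tensorflow as tf

class Dense(tf.Module):
  """A minimal dense layer written as a tf.Module (illustrative only)."""
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(tf.random.normal([in_features, out_features]), name="w")
    self.b = tf.Variable(tf.zeros([out_features]), name="b")
    # A non-trainable variable, here used as a call counter.
    self.calls = tf.Variable(0, trainable=False, name="calls")

  def __call__(self, x):
    self.calls.assign_add(1)
    return tf.nn.relu(tf.matmul(x, self.w) + self.b)

class MLP(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    # Sub-modules assigned as attributes are tracked automatically.
    self.layer_1 = Dense(4, 8)
    self.layer_2 = Dense(8, 2)

  def __call__(self, x):
    return self.layer_2(self.layer_1(x))

mlp = MLP()
mlp(tf.ones([1, 4]))
print(len(mlp.variables))            # 6: w, b and calls from both layers
print(len(mlp.trainable_variables))  # 4: the counters are excluded
```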
Keras layers and models are trackable objects that work with tf.train.Checkpoint and integrate with @tf.function, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras' Model.fit API to take advantage of these integrations.
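As a sketch, here is a single Keras layer saved and restored with tf.train.Checkpoint, with no Model.fit involved. The temporary directory and the "ckpt" prefix are arbitrary choices for the example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(2)
layer.build((None, 3))  # create the kernel and bias variables

# Save the layer's variables to a checkpoint in a temporary directory.
ckpt = tf.train.Checkpoint(layer=layer)
save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

original_kernel = layer.kernel.numpy()
layer.kernel.assign(np.zeros_like(original_kernel))  # overwrite the weights

# Restoring puts the saved values back into the existing variables.
ckpt.restore(save_path)
restored_kernel = layer.kernel.numpy()
```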
Read the section on transfer learning and fine-tuning in the Keras guide to learn how to collect a subset of relevant variables using Keras.
Combine tf.data.Datasets and tf.function
The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects. For this example, you can load the MNIST dataset using tfds:
Then prepare the data for training:
Re-scale each image.
Shuffle the order of the examples.
Collect batches of images and labels.
To keep the example short, trim the dataset to only return 5 batches:
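The preparation steps above can be sketched as a tf.data pipeline. To avoid a download, this version uses random uint8 tensors with MNIST's shape (28x28x1 images, labels in [0, 10)) in place of the tfds data. BUFFER_SIZE and BATCH_SIZE are illustrative values.

```python
import tensorflow as tf

# Stand-in for the MNIST training split: random "images" and integer labels.
num_examples = 1000
images = tf.cast(
    tf.random.uniform([num_examples, 28, 28, 1], maxval=256, dtype=tf.int32),
    tf.uint8)
labels = tf.random.uniform([num_examples], maxval=10, dtype=tf.int64)
mnist_train = tf.data.Dataset.from_tensor_slices((images, labels))

BUFFER_SIZE = 1000  # shuffle buffer size (illustrative)
BATCH_SIZE = 64     # illustrative

def scale(image, label):
  # Re-scale pixel values from [0, 255] to [0, 1].
  image = tf.cast(image, tf.float32) / 255.0
  return image, label

# Re-scale, shuffle, then batch.
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Trim the dataset to only return 5 batches.
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
```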
Use regular Python iteration to iterate over training data that fits in memory. Otherwise, tf.data.Dataset is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in tf.function, which replaces Python iteration with the equivalent graph operations using AutoGraph.
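Here is a small illustration of both styles. Iterating over a dataset in eager Python and inside a tf.function gives the same result. In the second case, AutoGraph converts the for loop into graph operations.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(3)  # [0,1,2], [3,4,5], [6,7,8], [9]

# Eager Python iteration: a Dataset is an ordinary iterable.
eager_total = sum(int(tf.reduce_sum(batch)) for batch in dataset)

@tf.function
def total():
  # AutoGraph converts this Python for loop over a Dataset into graph ops.
  result = tf.constant(0, dtype=tf.int64)
  for batch in dataset:
    result += tf.reduce_sum(batch)
  return result

print(eager_total, int(total()))  # 45 45
```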
If you use the Keras Model.fit API, you won't have to worry about dataset iteration.
Use Keras training loops
If you don't need low-level control of your training process, use Keras' built-in fit, evaluate, and predict methods. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or subclassed).
The advantages of these methods include:
They accept NumPy arrays, Python generators, and tf.data.Datasets.
They apply regularization and activation losses automatically.
They support tf.distribute, so the training code remains the same regardless of the hardware configuration.
They support arbitrary callables as losses and metrics.
They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
They are performant, automatically using TensorFlow graphs.
Here is an example of training a model using a Dataset. For details on how this works, check out the tutorials.
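A sketch of that workflow follows. It uses random stand-in data in place of MNIST, and the small convolutional model, its L2 regularizer, and the epoch count are illustrative choices. Note that fit adds the regularization loss automatically.

```python
import tensorflow as tf

# Synthetic stand-in data (random images and labels), batched as a Dataset.
images = tf.random.uniform([256, 28, 28, 1])
labels = tf.random.uniform([256], maxval=10, dtype=tf.int64)
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(8, 3, activation="relu",
                           kernel_regularizer=tf.keras.regularizers.l2(0.02)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(10),
])

# fit() handles batching, the regularization loss, and dropout's training flag.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])

history = model.fit(train_data, epochs=2, verbose=0)
loss, acc = model.evaluate(train_data, verbose=0)
print("Loss:", loss, "Accuracy:", acc)
```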
Customize training and write your own loop
If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on customizing fit to learn more.
You can also implement many things as a tf.keras.callbacks.Callback.
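For example, this hypothetical LossRecorder callback collects the training loss at the end of every epoch without touching the training step itself. It relies only on the standard on_epoch_end(epoch, logs) hook.

```python
import tensorflow as tf

class LossRecorder(tf.keras.callbacks.Callback):
  """Hypothetical callback that records the training loss after each epoch."""
  def __init__(self):
    super().__init__()
    self.epoch_losses = []

  def on_epoch_end(self, epoch, logs=None):
    # logs holds the metric values Keras computed for this epoch.
    self.epoch_losses.append(logs["loss"])

x = tf.random.normal([64, 4])
y = tf.random.normal([64, 1])
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = LossRecorder()
model.fit(x, y, epochs=3, batch_size=16, verbose=0, callbacks=[recorder])
print(recorder.epoch_losses)
```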
Writing your own training step or loop keeps many of the advantages mentioned previously, while giving you control of the train step and even the outer loop.
There are three steps to a standard training loop:
Iterate over a Python generator or tf.data.Dataset to get batches of examples.
Use tf.GradientTape to collect gradients.
Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
Remember:
Always include a training argument on the call method of subclassed layers and models.
Make sure to call the model with the training argument set correctly.
Depending on usage, model variables may not exist until the model is run on a batch of data.
You need to manually handle things like regularization losses for the model.
There is no need to run variable initializers or to add manual control dependencies. tf.function handles automatic control dependencies and variable initialization on creation for you.
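Putting these steps together, a custom loop might look like the following sketch. It uses random stand-in data, and the model architecture is illustrative. The model is called with training=True so that dropout is active, and the regularization losses from model.losses are added to the loss by hand. The optimizer's own variables are created on the first call of the tf.function, with no initializer step.

```python
import tensorflow as tf

images = tf.random.uniform([128, 28, 28, 1])
labels = tf.random.uniform([128], maxval=10, dtype=tf.int64)
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu",
                          kernel_regularizer=tf.keras.regularizers.l2(0.02)),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(10),
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, batch_labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)  # enables dropout
    # Regularization losses are not added automatically in a custom loop.
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(batch_labels, predictions)
    total_loss = pred_loss + regularization_loss
  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return total_loss

for epoch in range(2):
  for inputs, batch_labels in train_data:
    loss = train_step(inputs, batch_labels)
  print("Finished epoch", epoch, "loss:", float(loss))

final_loss = float(loss)
```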
Take advantage of tf.function with Python control flow
tf.function provides a way to convert data-dependent control flow into graph-mode equivalents like tf.cond and tf.while_loop.
One common place where data-dependent control flow appears is in sequence models. tf.keras.layers.RNN wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll with a loop over tf.range inside a tf.function.
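Here is a minimal sketch of dynamic unrolling. To stay self-contained and version-independent, it uses a hand-written tanh recurrence, h_t = tanh(x_t W + h_{t-1} U), in place of a Keras RNN cell. The weights W and U are hypothetical. Because the time dimension in the input signature is None, the same traced graph handles sequences of any length.

```python
import tensorflow as tf

units, features = 4, 3
# Hypothetical weights for a simple tanh RNN cell.
W = tf.Variable(tf.random.normal([features, units]) * 0.1)
U = tf.Variable(tf.random.normal([units, units]) * 0.1)

@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, features],
                                            dtype=tf.float32)])
def dynamic_rnn(input_data):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  timesteps = tf.shape(input_data)[0]
  batch_size = tf.shape(input_data)[1]
  outputs = tf.TensorArray(tf.float32, size=timesteps)
  state = tf.zeros([batch_size, units])
  # AutoGraph converts this loop into a tf.while_loop, so the number of
  # steps is decided at run time.
  for i in tf.range(timesteps):
    state = tf.tanh(tf.matmul(input_data[i], W) + tf.matmul(state, U))
    outputs = outputs.write(i, state)
  # [time, batch, units] -> [batch, time, units]
  return tf.transpose(outputs.stack(), [1, 0, 2]), state

outputs, final_state = dynamic_rnn(tf.random.normal([2, 5, features]))
longer_outputs, _ = dynamic_rnn(tf.random.normal([2, 7, features]))
```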
Read the tf.function guide for more information.
New-style metrics and losses
Metrics and losses are both objects that work eagerly and in tf.functions.
A loss object is callable, and expects (y_true, y_pred) as arguments:
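For instance, with logits [-1, 3] and a one-hot target on the first class, categorical cross-entropy equals log(1 + e^4), or about 4.01815. The same loss object also works unchanged inside a tf.function.

```python
import tensorflow as tf

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Called eagerly with (y_true, y_pred).
loss = cce([[1.0, 0.0]], [[-1.0, 3.0]])
print(float(loss))  # about 4.01815

@tf.function
def graph_loss(y_true, y_pred):
  # The same loss object, called inside a graph.
  return cce(y_true, y_pred)

graph_value = float(graph_loss(tf.constant([[1.0, 0.0]]),
                               tf.constant([[-1.0, 3.0]])))
```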
Use metrics to collect and display data
You can use tf.metrics to aggregate data and tf.summary to log summaries, redirecting them to a writer with a context manager. The summaries are emitted directly to the writer, which means that you must provide the step value at the call site.
Use tf.metrics to aggregate data before logging it as summaries. Metrics are stateful: they accumulate values and return a cumulative result when you call the result method (such as Mean.result). Clear accumulated values with the metric's reset_state method (reset_states in older releases).
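The sketch below combines the two APIs. A Mean metric accumulates per-batch losses, the result is written with an explicit step, and then the metric is reset. The loss values and the temporary log directory are made up for illustration.

```python
import tempfile

import tensorflow as tf

log_dir = tempfile.mkdtemp()  # illustrative log directory
summary_writer = tf.summary.create_file_writer(log_dir)
loss_metric = tf.keras.metrics.Mean(name="loss")

# Made-up per-batch losses for two "epochs".
fake_epoch_losses = [[1.0, 2.0, 3.0], [4.0, 6.0]]
logged = []

with summary_writer.as_default():
  for step, batch_losses in enumerate(fake_epoch_losses):
    for value in batch_losses:
      loss_metric.update_state(value)  # accumulate
    # The writer needs the step value at the call site.
    tf.summary.scalar("loss", loss_metric.result(), step=step)
    logged.append(float(loss_metric.result()))
    loss_metric.reset_state()  # clear accumulated values
summary_writer.flush()
print(logged)  # [2.0, 5.0]
```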
Visualize the generated summaries by pointing TensorBoard to the summary log directory, for example by running tensorboard --logdir with that directory.
Use the tf.summary API to write summary data for visualization in TensorBoard. For more info, read the tf.summary guide.
Keras metric names
Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit, and in the logs passed to keras.callbacks.
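In this small illustration, the string "acc" becomes a metric named "acc", and a metric object keeps the name it was constructed with (the hypothetical "my_accuracy"). Both names then appear as keys in history.history. The data and model are random stand-ins.

```python
import tensorflow as tf

x = tf.random.normal([64, 8])
y = tf.random.uniform([64], maxval=3, dtype=tf.int64)
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(3)])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["acc",  # the string itself is used as the metric name
             tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])

history = model.fit(x, y, epochs=1, verbose=0)
print(sorted(history.history.keys()))
```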
Debugging
Use eager execution to run your code step-by-step to inspect shapes, data types and values. Certain APIs, like tf.function and tf.keras, are designed to use graph execution for performance and portability. When debugging, use tf.config.run_functions_eagerly(True) to run this code eagerly.
For example:
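In this sketch, a Python side effect records tf.executing_eagerly() each time the function body runs. In graph mode, the body runs once during tracing and reports False. After tf.config.run_functions_eagerly(True), the body runs as ordinary Python on every call, so a debugger breakpoint placed there would trigger. The setting is reset afterwards because it is global.

```python
import tensorflow as tf

trace_log = []  # records whether the Python body ran eagerly

@tf.function
def f(x):
  # Python side effect: runs once while tracing in graph mode, but on every
  # call when functions are forced to run eagerly.
  trace_log.append(tf.executing_eagerly())
  if x > 0:
    # With eager execution enabled you could stop here with a debugger,
    # e.g. `import pdb; pdb.set_trace()`, and inspect x.numpy().
    x = x + 1
  return x

f(tf.constant(1))  # traced and run as a graph
tf.config.run_functions_eagerly(True)
try:
  result = f(tf.constant(1))  # body executes as ordinary Python
finally:
  tf.config.run_functions_eagerly(False)  # restore the default

print(trace_log, int(result))
```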
This also works inside Keras models and other APIs that support eager execution:
Notes:
tf.keras.Modelmethods such asfit,evaluate, andpredictexecute as graphs withtf.functionunder the hood.When using
tf.keras.Model.compile, setrun_eagerly = Trueto disable theModellogic from being wrapped in atf.function.Use
tf.data.experimental.enable_debug_modeto enable the debug mode fortf.data. Read the API docs for more details.
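A short sketch of the compile flag follows. With run_eagerly=True, fit runs each training step eagerly instead of inside a tf.function, so print calls or breakpoints in custom layers run per batch. The model and data are random placeholders.

```python
import tensorflow as tf

x = tf.random.normal([32, 4])
y = tf.random.normal([32, 1])
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

# Keep fit/evaluate/predict from wrapping the step logic in a tf.function.
model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
history = model.fit(x, y, epochs=1, batch_size=8, verbose=0)
print(model.run_eagerly)
```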
Do not keep tf.Tensors in your objects
Tensors can be created either inside a tf.function or in the eager context, and the two behave differently: a tensor created while a tf.function is being traced is a symbolic graph tensor that is not usable outside that graph. Use tf.Tensors only for intermediate values.
To track state, use tf.Variables as they are always usable from both contexts. Read the tf.Variable guide to learn more.
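Here is a minimal sketch using a hypothetical Counter module. The running total lives in a tf.Variable and is updated with assign_add inside a tf.function. The value can then be read back from eager code afterwards.

```python
import tensorflow as tf

class Counter(tf.Module):
  """Hypothetical module whose state is kept in a tf.Variable."""
  def __init__(self):
    super().__init__()
    # A tf.Variable works both in eager code and inside tf.function.
    self.count = tf.Variable(0, dtype=tf.int64)

  @tf.function
  def increment(self, amount):
    # Avoid storing `amount` (or any tensor created here) on self: inside
    # the function it is a graph tensor that is not usable outside it.
    self.count.assign_add(amount)
    return self.count.read_value()

counter = Counter()
counter.increment(tf.constant(1, dtype=tf.int64))
counter.increment(tf.constant(2, dtype=tf.int64))
print(int(counter.count.numpy()))  # 3
```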
Resources and further reading
Read the TF2 guides and tutorials to learn more about how to use TF2.
If you previously used TF1.x, it is highly recommended that you migrate your code to TF2. Read the migration guides to learn more.