Training a neural network on MNIST with Keras
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.
Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0
Step 1: Create your input pipeline
Start by building an efficient input pipeline using advice from:

* The Performance tips guide
* The Better performance with the `tf.data` API guide
Load a dataset
Load the MNIST dataset with the following arguments:
* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.
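A minimal sketch of the load step with these arguments, assuming the standard `tfds.load` API and the `'mnist'` dataset name (the variable names `ds_train`, `ds_test`, and `ds_info` are illustrative and reused in the sketches below):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load train/test splits as (image, label) tuples, plus dataset metadata.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
```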
Build a training pipeline
Apply the following transformations:
* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize images.
* `tf.data.Dataset.cache`: As you fit the dataset in memory, cache it before shuffling for better performance.
  Note: Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.
  Note: For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching for performance.
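One way these transformations might be chained, assuming the `ds_train` and `ds_info` objects from the load sketch above; the batch size of 128 is an illustrative choice:

```python
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32` in [0, 1]."""
  return tf.cast(image, tf.float32) / 255., label

# Normalize, cache, then shuffle with a buffer covering the full train split.
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling, and prefetch so the pipeline overlaps with training.
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
```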
Build an evaluation pipeline
Your testing pipeline is similar to the training pipeline with small differences:
* You don't need to call `tf.data.Dataset.shuffle`.
* Caching is done after batching because batches can be the same between epochs.
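A corresponding sketch for the evaluation pipeline, reusing the hypothetical `normalize_img` helper and batch size from the training sketch:

```python
# Same normalization, but no shuffling; cache after batching.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
```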
Step 2: Create and train the model
Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.
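A minimal sketch of this step using the pipelines built above; the layer sizes, optimizer settings, and epoch count are illustrative assumptions, not prescribed values:

```python
# A small fully connected classifier for 28x28 MNIST images.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train on the TFDS training pipeline, validating on the test pipeline.
model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
```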