# Performance tips
This document provides TensorFlow Datasets (TFDS)-specific performance tips. Note that TFDS provides datasets as `tf.data.Dataset` objects, so the advice from the `tf.data` guide still applies.
## Benchmark datasets
Use `tfds.benchmark(ds)` to benchmark any `tf.data.Dataset` object.

Make sure to indicate the `batch_size=` to normalize the results (e.g. 100 iter/sec -> 3200 ex/sec). This works with any iterable (e.g. `tfds.benchmark(tfds.as_numpy(ds))`).
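For example, a minimal sketch (MNIST is used purely for illustration; any dataset works):

```python
import tensorflow_datasets as tfds

ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)

tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32)  # The second epoch is much faster due to auto-caching
```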
## Small datasets (less than 1 GB)
All TFDS datasets store the data on disk in the TFRecord format. For small datasets (e.g. MNIST, CIFAR-10/-100), reading from `.tfrecord` can add significant overhead.

As those datasets fit in memory, it is possible to significantly improve the performance by caching or pre-loading the dataset. Note that TFDS automatically caches small datasets (the following section has the details).
### Caching the dataset
Here is an example of a data pipeline which explicitly caches the dataset after normalizing the images.
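A sketch of such a pipeline, assuming an MNIST-style `(image, label)` dataset (names and hyper-parameters are illustrative):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # Returns (image, label) tuples
    with_info=True,
)
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()  # Cache after the deterministic normalization step
# For true randomness, set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
ds = ds.batch(128)
ds = ds.prefetch(tf.data.AUTOTUNE)
```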
When iterating over this dataset, the second iteration will be much faster than the first one thanks to the caching.
### Auto-caching
By default, TFDS auto-caches (with `ds.cache()`) datasets which satisfy the following constraints:

*   Total dataset size (all splits) is defined and < 250 MiB
*   `shuffle_files` is disabled, or only a single shard is read
It is possible to opt out of auto-caching by passing `try_autocache=False` to `tfds.ReadConfig` in `tfds.load`. Have a look at the dataset catalog documentation to see if a specific dataset will use auto-cache.
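For example (a sketch; `try_autocache` is assumed to be the corresponding `tfds.ReadConfig` field, and MNIST is only an illustration):

```python
ds = tfds.load(
    'mnist',
    split='train',
    read_config=tfds.ReadConfig(try_autocache=False),
)
```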
### Loading the full data as a single Tensor
If your dataset fits into memory, you can also load the full dataset as a single Tensor or NumPy array. It is possible to do so by setting `batch_size=-1` to batch all examples in a single `tf.Tensor`. Then use `tfds.as_numpy` for the conversion from `tf.Tensor` to `np.array`.
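For example, a sketch with MNIST:

```python
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))
```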
## Large datasets
Large datasets are sharded (split into multiple files) and typically do not fit in memory, so they should not be cached.
### Shuffle and training
During training, it's important to shuffle the data well - poorly shuffled data can result in lower training accuracy.
In addition to using `ds.shuffle` to shuffle records, you should also set `shuffle_files=True` to get good shuffling behavior for larger datasets that are sharded into multiple files. Otherwise, epochs will read the shards in the same order, and so data won't be truly randomized.
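For example (a sketch; `imagenet2012` stands in for any multi-shard dataset):

```python
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
```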
Additionally, when `shuffle_files=True`, TFDS disables `options.deterministic`, which may give a slight performance boost. To get deterministic shuffling, it is possible to opt out of this feature with `tfds.ReadConfig`: either by setting `read_config.shuffle_seed` or overwriting `read_config.options.deterministic`.
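A sketch of restoring determinism through `tfds.ReadConfig` (the seed value and dataset name are arbitrary):

```python
read_config = tfds.ReadConfig(shuffle_seed=42)
ds = tfds.load(
    'imagenet2012',
    split='train',
    shuffle_files=True,
    read_config=read_config,
)
```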
### Auto-shard your data across workers (TF)
When training on multiple workers, you can use the `input_context` argument of `tfds.ReadConfig`, so each worker will read a subset of the data.
This is complementary to the subsplit API. First, the subsplit API is applied: `train[:50%]` is converted into a list of files to read. Then, a `ds.shard()` op is applied on those files. For example, when using `train[:50%]` with `num_input_pipelines=2`, each of the 2 workers will read 1/4 of the data.
When `shuffle_files=True`, files are shuffled within one worker, but not across workers. Each worker will read the same subset of files between epochs.
Note: When using `tf.distribute.Strategy`, the `input_context` can be automatically created with `distribute_datasets_from_function`.
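A sketch of passing an `input_context` manually (the dataset name and worker counts are placeholders):

```python
ds = tfds.load('my_dataset', split='train', read_config=tfds.ReadConfig(
    input_context=tf.distribute.InputContext(
        input_pipeline_id=1,    # Worker id
        num_input_pipelines=4,  # Total number of workers
    ),
))
```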
### Auto-shard your data across workers (Jax)
With Jax, you can use the `tfds.split_for_jax_process` or `tfds.even_splits` API to distribute your data across workers. See the split API guide.
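For example (a sketch; `my_dataset` is a placeholder):

```python
split = tfds.split_for_jax_process('train')
ds = tfds.load('my_dataset', split=split)
```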
`tfds.split_for_jax_process` is a simple alias for:
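Roughly the following (a sketch; the exact implementation may differ):

```python
import jax

# Each host reads only its own 1/process_count slice of the split.
splits = tfds.even_splits('train', n=jax.process_count())
split = splits[jax.process_index()]
```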
## Faster image decoding
By default, TFDS automatically decodes images. However, there are cases where it can be more performant to skip the image decoding with `tfds.decode.SkipDecoding` and manually apply the `tf.io.decode_image` op:

*   When filtering examples (with `tf.data.Dataset.filter`), to decode images after examples have been filtered.
*   When cropping images, to use the fused `tf.image.decode_and_crop_jpeg` op.
The code for both examples is available in the decode guide.
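As an illustrative sketch (not the exact code from the guide), skipping decoding and decoding manually after filtering might look like:

```python
ds = tfds.load('imagenet2012', split='train', decoders={
    'image': tfds.decode.SkipDecoding(),  # Keep the image as encoded bytes
})
# ... apply ds.filter(...) on the other, already-decoded features here ...
ds = ds.map(lambda ex: dict(ex, image=tf.io.decode_image(ex['image'])))
```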
## Skip unused features
If you're only using a subset of the features, it is possible to entirely skip some features. If your dataset has many unused features, not decoding them can significantly improve performance. See https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features.
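A sketch using `tfds.decode.PartialDecoding` (the dataset and feature names are placeholders):

```python
builder = tfds.builder('my_dataset')
ds = builder.as_dataset(
    split='train',
    decoders=tfds.decode.PartialDecoding({
        'image': True,
        'label': True,
    }),
)
```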