GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/en-snapshot/datasets/tfless_tfds.ipynb
```python
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

TFDS for JAX and PyTorch

TFDS has always been framework-agnostic. For instance, you can easily load datasets in NumPy format for use in JAX and PyTorch.

TensorFlow and its data loading solution (tf.data) are first-class citizens in our API by design.

We extended TFDS to support NumPy-only data loading without TensorFlow. This is convenient for ML frameworks such as JAX and PyTorch, because for their users TensorFlow can:

  • reserve GPU/TPU memory;

  • increase build time in CI/CD;

  • take time to import at runtime.

TensorFlow is no longer required to read prepared datasets.

ML pipelines need a data loader to load examples, decode them, and present them to the model. Data loaders use the "source/sampler/loader" paradigm:

```
   TFDS dataset              ┌────────────────┐
     on disk                 │                │
                   ┌────────►│  Data source   ├──┐
   |..|...         │         │                │  │
   ├──┼────────────┤         └────────────────┘  │     ┌────────────────┐
   │12│image12     │                             │     │                │
   ├──┼────────────┤                             ├────►│  Data loader   ├────► ML pipeline
   │13│image13     │                             │     │                │
   ├──┼────────────┤         ┌────────────────┐  │     └────────────────┘
   │14│image14     │         │                │  │
   ├──┼────────────┤         │ Index sampler  ├──┘
   |..|...                   │                │
                             └────────────────┘
```
  • The data source is responsible for accessing and decoding examples from a TFDS dataset on the fly.

  • The index sampler is responsible for determining the order in which records are processed. This makes it possible to implement global transformations (e.g., global shuffling, sharding, repeating for multiple epochs) before reading any records.

  • The data loader orchestrates the loading by leveraging the data source and the index sampler. It is where performance optimizations happen (e.g., prefetching, multiprocessing or multithreading).
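To make the three roles concrete, here is a minimal pure-Python sketch. It is not TFDS code: `InMemorySource`, `index_sampler` and `data_loader` are hypothetical names, a plain list stands in for the records on disk, and the loader only batches (no prefetching or parallelism).

```python
import random
from typing import Any, Iterator, List, Sequence


class InMemorySource:
  """Hypothetical data source: random access to records by index."""

  def __init__(self, records: Sequence[Any]):
    self._records = records

  def __len__(self) -> int:
    return len(self._records)

  def __getitem__(self, index: int) -> Any:
    # A real source would read and decode the record from disk here.
    return self._records[index]


def index_sampler(num_records: int, num_epochs: int, seed: int) -> Iterator[int]:
  """Yields a globally shuffled index order, reshuffled for every epoch."""
  rng = random.Random(seed)
  for _ in range(num_epochs):
    indices = list(range(num_records))
    rng.shuffle(indices)  # Each epoch's order is decided before its records are read.
    yield from indices


def data_loader(source: InMemorySource, sampler: Iterator[int],
                batch_size: int) -> Iterator[List[Any]]:
  """Fetches records in the sampler's order and groups them into batches."""
  batch = []
  for index in sampler:
    batch.append(source[index])
    if len(batch) == batch_size:
      yield batch
      batch = []
  if batch:  # Emit the last, possibly smaller, batch.
    yield batch


source = InMemorySource([f'image{i}' for i in range(10)])
sampler = index_sampler(len(source), num_epochs=2, seed=0)
batches = list(data_loader(source, sampler, batch_size=4))
print(len(batches))  # 5 (2 epochs x 10 records, 4 records per batch)
```

Keeping the sampler separate from the source is what lets global shuffling and repetition be decided purely on indices, without touching the data.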

TL;DR

tfds.data_source is an API to create data sources:

  1. for fast prototyping in pure-Python pipelines;

  2. to manage data-intensive ML pipelines at scale.

Setup

Let's install and import the needed dependencies:

```python
!pip install array_record
!pip install tfds-nightly

import os
# Unset TFDS_DATA_DIR (if set) so TFDS uses its default data directory.
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds
```

Data sources

Data sources are essentially Python sequences, so they need to implement the following protocol:

```python
from typing import Any, Protocol, Sequence


class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, record_key: int) -> Sequence[Any]:
    """Retrieves records for the given record_keys."""
```

Warning: the API is still under active development. Notably, __getitem__ currently must accept both int and list[int] as input. In the future, it will probably only accept int, as the standard sequence protocol does.
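As an illustration of the protocol and of the current int/list[int] requirement, here is a hypothetical `ListDataSource` backed by an in-memory Python list. It is a sketch under those assumptions, not the TFDS implementation, which reads ArrayRecord files.

```python
from typing import Any, List, Sequence, Union


class ListDataSource:
  """Hypothetical random-access data source backed by an in-memory list."""

  def __init__(self, records: Sequence[Any]):
    self._records = list(records)

  def __len__(self) -> int:
    """Number of records in the dataset."""
    return len(self._records)

  def __getitem__(self, record_key: Union[int, List[int]]) -> Any:
    """Returns one record for an int key, or a list of records for a list key."""
    if isinstance(record_key, list):
      return [self._records[key] for key in record_key]
    # Raises IndexError past the end, which also ends plain iteration.
    return self._records[record_key]


source = ListDataSource(['a', 'b', 'c'])
print(len(source))     # 3
print(source[1])       # b
print(source[[0, 2]])  # ['a', 'c']
print(list(source))    # ['a', 'b', 'c']
```

The class defines no `__iter__`, yet `list(source)` works: Python falls back to calling `__getitem__` with 0, 1, 2, ... until it raises `IndexError`.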

The underlying file format needs to support efficient random access. At the moment, TFDS relies on array_record.

array_record is a new file format built on top of Riegeli and designed for I/O efficiency. In particular, ArrayRecord supports parallel reads and writes, as well as random access by record index, and it supports the same compression algorithms as Riegeli.

fashion_mnist is a common dataset for computer vision. To retrieve an ArrayRecord-based data source with TFDS, simply use:

```python
ds = tfds.data_source('fashion_mnist')
```

tfds.data_source is a convenient wrapper. It is equivalent to:

```python
builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()
```

This outputs a dictionary of data sources:

```
{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}
```

Once download_and_prepare has run and generated the record files, TensorFlow is no longer needed. Everything happens in Python/NumPy!

Let's check this by uninstalling TensorFlow and re-loading the data source in another subprocess:

```python
!pip uninstall -y tensorflow
```

```python
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')
```

```python
!python no_tensorflow.py
```

In future versions, we also plan to make dataset preparation TensorFlow-free.

A data source has a length:

```python
len(ds['train'])
```

Accessing the first element of the dataset:

```python
%%timeit
ds['train'][0]
```

...is just as cheap as accessing any other element. This is the definition of random access:

```python
%%timeit
ds['train'][1000]
```
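`%%timeit` is an IPython cell magic. Outside a notebook, the standard library's `timeit` module gives a comparable measurement. The sketch below uses a hypothetical helper, `time_lookup`, on a plain list standing in for a data source; with TFDS you would pass `ds['train']` instead, assuming it was created as above.

```python
import timeit


def time_lookup(source, index: int, number: int = 1000) -> float:
  """Returns the average number of seconds per `source[index]` lookup."""
  total = timeit.timeit(lambda: source[index], number=number)
  return total / number


stand_in_source = list(range(60_000))  # Hypothetical stand-in for ds['train'].
first = time_lookup(stand_in_source, 0)
later = time_lookup(stand_in_source, 1000)
print(f'index 0: {first * 1e9:.0f} ns, index 1000: {later * 1e9:.0f} ns')
```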

Features now use NumPy DTypes (rather than TensorFlow DTypes). You can inspect the features with:

```python
features = tfds.builder('fashion_mnist').info.features
```

You'll find more information about the features in our documentation. Here we can notably retrieve the shape of the images, and the number of classes:

```python
shape = features['image'].shape
num_classes = features['label'].num_classes
```
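These two values are enough to size a simple model, such as the linear classifier trained later in this guide. As a worked example, assuming Fashion-MNIST's image shape (28, 28, 1) and its 10 classes (hardcoded below rather than read from `features`), a flattened image has 28 × 28 × 1 = 784 values, and a linear layer from 784 inputs to 10 outputs has 784 × 10 weights plus 10 biases. `flat_size` and `linear_num_params` are hypothetical helpers.

```python
import math
from typing import Tuple


def flat_size(shape: Tuple[int, ...]) -> int:
  """Number of values in one image once flattened to a vector."""
  return math.prod(shape)


def linear_num_params(shape: Tuple[int, ...], num_classes: int) -> int:
  """Weights plus biases of a linear layer from a flattened image to class scores."""
  return flat_size(shape) * num_classes + num_classes


example_shape = (28, 28, 1)  # Assumed Fashion-MNIST image shape.
example_num_classes = 10     # Assumed number of Fashion-MNIST classes.
print(flat_size(example_shape))                               # 784
print(linear_num_params(example_shape, example_num_classes))  # 7850
```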

Use in pure Python

You can consume data sources in Python by iterating over them:

```python
for example in ds['train']:
  print(example)
  break
```

If you inspect elements, you will also notice that all features are already decoded as NumPy arrays. Behind the scenes, TFDS uses OpenCV for image decoding by default because it is fast. If OpenCV is not installed, it falls back to Pillow, which provides lightweight and fast image decoding.

```
{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}
```

Note: Currently, this is only available for Tensor, Image and Scalar features. Audio and Video features will come soon. Stay tuned!
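In pure Python you usually also need to batch examples. A minimal sketch, assuming each example is a dict of NumPy arrays and scalars like the one above: the hypothetical helpers `collate` and `get_batch` read the records at a list of indices and stack each feature with `np.stack`. Random arrays stand in for real Fashion-MNIST records so the snippet runs on its own.

```python
from typing import Any, Dict, List, Sequence

import numpy as np


def collate(examples: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
  """Stacks a list of feature dicts into one dict of batched arrays."""
  return {key: np.stack([example[key] for example in examples]) for key in examples[0]}


def get_batch(source: Sequence[Dict[str, Any]], indices: List[int]) -> Dict[str, np.ndarray]:
  """Reads the records at `indices` one at a time, then batches them."""
  return collate([source[i] for i in indices])


# Stand-in records shaped like Fashion-MNIST examples (hypothetical data).
rng = np.random.default_rng(0)
stand_in_source = [
    {'image': rng.integers(0, 256, size=(28, 28, 1), dtype=np.uint8), 'label': i % 10}
    for i in range(100)
]
batch = get_batch(stand_in_source, [3, 14, 15, 92])
print(batch['image'].shape, batch['image'].dtype)  # (4, 28, 28, 1) uint8
print(batch['label'])                              # [3 4 5 2]
```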

Use with PyTorch

PyTorch uses the source/sampler/loader paradigm. In Torch, "data sources" are called "datasets". The torch.utils.data documentation contains all the details you need to build efficient input pipelines in Torch.

TFDS data sources can be used as regular map-style datasets.

First we install and import Torch:

```python
!pip install torch

from tqdm import tqdm
import torch
```

We already defined data sources for training and testing (respectively, ds['train'] and ds['test']). We can now define the sampler and the loaders:

```python
batch_size = 128
# Draw 5,000 random training indices per pass (without replacement, the default).
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
# No sampler: the test set is read sequentially.
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)
```
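To see what this configuration produces, here is a stdlib-only sketch of the same sampling-then-batching logic. It is not PyTorch's implementation, and `sample_and_batch` is a hypothetical helper. Drawing 5,000 distinct indices out of Fashion-MNIST's 60,000 training examples and splitting them into batches of 128 gives ⌈5000 / 128⌉ = 40 batches: 39 full ones and a last batch of 8 examples, which `DataLoader` keeps because `drop_last` defaults to False.

```python
import math
import random


def sample_and_batch(num_records: int, num_samples: int, batch_size: int, seed: int = 0):
  """Draws distinct random indices, then splits them into consecutive batches."""
  indices = random.Random(seed).sample(range(num_records), num_samples)
  return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]


batches = sample_and_batch(num_records=60_000, num_samples=5_000, batch_size=128)
print(len(batches), math.ceil(5_000 / 128))  # 40 40
print(len(batches[0]), len(batches[-1]))     # 128 8
```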

Using PyTorch, we train a simple logistic regression on the 5,000 sampled training examples and evaluate it on the test set:

```python
class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super().__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    # Flatten each image to a vector and cast the uint8 pixels to float32.
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)


model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
with torch.no_grad():  # Gradients are not needed for evaluation.
  for example in tqdm(test_loader):
    image, label = example['image'], example['label']
    prediction = model(image)
    num_examples += image.shape[0]
    predicted_label = prediction.argmax(dim=1)
    true_positives += (predicted_label == label).sum().item()

print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')
```

Coming soon: use with JAX

We are working closely with Grain, an open-source, fast and deterministic data loader for Python. Stay tuned!

Read more

For more information, please refer to the tfds.data_source API documentation.