Path: blob/master/labml_nn/helpers/datasets.py
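"""
Dataset helpers: configurable MNIST and CIFAR-10 data sets with data
loaders, and utilities for character/token level text datasets with
sequential samplers.
"""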
import random
from pathlib import PurePath, Path
from typing import List, Callable, Dict, Optional

from torchvision import datasets, transforms

import torch
from labml import lab
from labml import monit
from labml.configs import BaseConfigs
from labml.configs import aggregate, option
from labml.utils.download import download_file
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset


def _mnist_dataset(is_train, transform):
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)


class MNISTConfigs(BaseConfigs):
    """
    Configurable MNIST data set.

    Arguments:
        dataset_name (str): name of the data set, ``MNIST``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.MNIST): training dataset
        valid_dataset (torchvision.datasets.MNIST): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """

    dataset_name: str = 'MNIST'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.MNIST
    valid_dataset: datasets.MNIST

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
    return _mnist_dataset(True, c.dataset_transforms)


@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
    return _mnist_dataset(False, c.dataset_transforms)


@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


aggregate(MNISTConfigs.dataset_name, 'MNIST',
          (MNISTConfigs.dataset_transforms, 'mnist_transforms'),
          (MNISTConfigs.train_dataset, 'mnist_train_dataset'),
          (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
          (MNISTConfigs.train_loader, 'mnist_train_loader'),
          (MNISTConfigs.valid_loader, 'mnist_valid_loader'))

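
# A minimal usage sketch (an illustration, not part of the module): with the
# labml experiment API, config options such as `train_loader` are resolved on
# demand from the `@option` functions above. `CIFAR10Configs` below is used
# the same way. The experiment name and batch size here are placeholders.
def _example_mnist_configs():
    from labml import experiment

    conf = MNISTConfigs()
    experiment.create(name='mnist_data_example')
    # Override a declared default; unspecified options keep their defaults
    experiment.configs(conf, {'train_batch_size': 128})
    # Accessing `train_loader` computes `mnist_train_loader`, which in turn
    # pulls in `train_dataset` and `dataset_transforms`
    data, target = next(iter(conf.train_loader))
    print(data.shape, target.shape)  # e.g. [128, 1, 28, 28] and [128]
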
def _cifar_dataset(is_train, transform):
    return datasets.CIFAR10(str(lab.get_data_path()),
                            train=is_train,
                            download=True,
                            transform=transform)


class CIFAR10Configs(BaseConfigs):
    """
    Configurable CIFAR 10 data set.

    Arguments:
        dataset_name (str): name of the data set, ``CIFAR10``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.CIFAR10): training dataset
        valid_dataset (torchvision.datasets.CIFAR10): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """
    dataset_name: str = 'CIFAR10'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.CIFAR10
    valid_dataset: datasets.CIFAR10

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
    return _cifar_dataset(True, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
    return _cifar_dataset(False, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@CIFAR10Configs.calc(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
                         (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
                         (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
                         (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
                         (CIFAR10Configs.train_loader, 'cifar10_train_loader'),
                         (CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))


class TextDataset:
    """
    Base class for text datasets. Builds a token vocabulary from the
    training and validation text, unless one is given.
    """
    itos: List[str]
    stoi: Dict[str, int]
    n_tokens: int
    train: str
    valid: str
    standard_tokens: List[str] = []

    @staticmethod
    def load(path: PurePath):
        with open(str(path), 'r') as f:
            return f.read()

    def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *,
                 n_tokens: Optional[int] = None,
                 stoi: Optional[Dict[str, int]] = None,
                 itos: Optional[List[str]] = None):
        self.test = test
        self.valid = valid
        self.train = train
        self.tokenizer = tokenizer
        self.path = path

        if n_tokens or stoi or itos:
            # Use the given vocabulary
            assert stoi and itos and n_tokens
            self.n_tokens = n_tokens
            self.stoi = stoi
            self.itos = itos
        else:
            # Build the vocabulary, starting with the standard tokens
            self.n_tokens = len(self.standard_tokens)
            self.stoi = {t: i for i, t in enumerate(self.standard_tokens)}

            with monit.section("Tokenize"):
                tokens = self.tokenizer(self.train) + self.tokenizer(self.valid)
                tokens = sorted(list(set(tokens)))

            for t in monit.iterate("Build vocabulary", tokens):
                self.stoi[t] = self.n_tokens
                self.n_tokens += 1

            self.itos = [''] * self.n_tokens
            for t, n in self.stoi.items():
                self.itos[n] = t

    def text_to_i(self, text: str) -> torch.Tensor:
        # Tokens that are not in the vocabulary are dropped
        tokens = self.tokenizer(text)
        return torch.tensor([self.stoi[s] for s in tokens if s in self.stoi], dtype=torch.long)

    def __repr__(self):
        return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}'


class SequentialDataLoader(IterableDataset):
    """
    Iterates through the text sequentially, yielding ``[seq_len, batch_size]``
    chunks of inputs with targets shifted one token ahead.
    """

    def __init__(self, *, text: str, dataset: TextDataset,
                 batch_size: int, seq_len: int):
        self.seq_len = seq_len
        data = dataset.text_to_i(text)
        # Trim to a multiple of the batch size and arrange as [sequence, batch]
        n_batch = data.shape[0] // batch_size
        data = data.narrow(0, 0, n_batch * batch_size)
        data = data.view(batch_size, -1).t().contiguous()
        self.data = data

    def __len__(self):
        return self.data.shape[0] // self.seq_len

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.data.shape[0] - 1:
            raise StopIteration()

        seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx)
        i = self.idx + seq_len
        data = self.data[self.idx: i]
        target = self.data[self.idx + 1: i + 1]
        self.idx = i
        return data, target

    def __getitem__(self, idx):
        seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx)
        i = idx + seq_len
        data = self.data[idx: i]
        target = self.data[idx + 1: i + 1]
        return data, target


class SequentialUnBatchedDataset(Dataset):
    """
    Maps indexes to ``seq_len`` long chunks of the text, optionally with a
    random offset, so that a ``DataLoader`` can batch and shuffle them.
    """

    def __init__(self, *, text: str, dataset: TextDataset,
                 seq_len: int,
                 is_random_offset: bool = True):
        self.is_random_offset = is_random_offset
        self.seq_len = seq_len
        self.data = dataset.text_to_i(text)

    def __len__(self):
        return (self.data.shape[0] - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        assert start + self.seq_len + 1 <= self.data.shape[0]
        if self.is_random_offset:
            start += random.randint(0, min(self.seq_len - 1, self.data.shape[0] - (start + self.seq_len + 1)))

        end = start + self.seq_len
        data = self.data[start: end]
        target = self.data[start + 1: end + 1]
        return data, target


class TextFileDataset(TextDataset):
    """
    Character/token level dataset loaded from a single text file; the last
    10% of the text is used for validation.
    """
    standard_tokens = []

    def __init__(self, path: PurePath, tokenizer: Callable, *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        path = Path(path)
        if not path.exists():
            if not url:
                raise FileNotFoundError(str(path))
            else:
                download_file(url, path)

        with monit.section("Load data"):
            text = self.load(path)
            if filter_subset:
                text = text[:filter_subset]
            split = int(len(text) * .9)
            train = text[:split]
            valid = text[split:]

        super().__init__(path, tokenizer, train, valid, '')

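
# A minimal sketch (an illustration, not part of the module) of the text
# pipeline: build a character-level vocabulary with `TextDataset`, then
# sample it with the two sequential samplers above. The toy strings and the
# `toy.txt` path are placeholders.
def _example_text_pipeline():
    ds = TextDataset(Path('toy.txt'), lambda x: list(x),
                     'the quick brown fox jumps over the lazy dog ' * 8,
                     'a lazy dog', '')
    print(ds.n_tokens, ds.text_to_i('fox'))

    # Contiguous [seq_len, batch_size] chunks; targets are shifted one token
    for data, target in SequentialDataLoader(text=ds.train, dataset=ds,
                                             batch_size=2, seq_len=8):
        assert data.shape == target.shape
        break

    # Index-addressable chunks, so a `DataLoader` can batch and shuffle them
    unbatched = SequentialUnBatchedDataset(text=ds.train, dataset=ds, seq_len=8)
    loader = DataLoader(unbatched, batch_size=2, shuffle=True)
    data, target = next(iter(loader))  # shapes: [batch_size, seq_len]
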
def _test_tiny_shakespeare():
    _ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
                        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


if __name__ == '__main__':
    _test_tiny_shakespeare()