Path: blob/master/labml_nn/unet/carvana.py
4939 views
"""1---2title: Carvana dataset for the U-Net experiment3summary: >4Carvana dataset for the U-Net experiment.5---67# Carvana Dataset for the [U-Net](index.html) [experiment](experiment.html)89You can find the download instructions10[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).1112Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder.13"""1415from pathlib import Path1617import torchvision.transforms.functional18from PIL import Image1920import torch.utils.data21from labml import lab222324class CarvanaDataset(torch.utils.data.Dataset):25"""26## Carvana Dataset27"""2829def __init__(self, image_path: Path, mask_path: Path):30"""31:param image_path: is the path to the images32:param mask_path: is the path to the masks33"""34# Get a dictionary of images by id35self.images = {p.stem: p for p in image_path.iterdir()}36# Get a dictionary of masks by id37self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}3839# Image ids list40self.ids = list(self.images.keys())4142# Transformations43self.transforms = torchvision.transforms.Compose([44torchvision.transforms.Resize(572),45torchvision.transforms.ToTensor(),46])4748def __getitem__(self, idx: int):49"""50#### Get an image and its mask.5152:param idx: is index of the image53"""5455# Get image id56id_ = self.ids[idx]57# Load image58image = Image.open(self.images[id_])59# Transform image and convert it to a PyTorch tensor60image = self.transforms(image)61# Load mask62mask = Image.open(self.masks[id_])63# Transform mask and convert it to a PyTorch tensor64mask = self.transforms(mask)6566# The mask values were not $1$, so we scale it appropriately.67mask = mask / mask.max()6869# Return the image and the mask70return image, mask7172def __len__(self):73"""74#### Size of the dataset75"""76return len(self.ids)777879# Testing code80if __name__ == '__main__':81ds = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks')828384