Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/unet/carvana.py
4939 views
1
"""
---
title: Carvana dataset for the U-Net experiment
summary: >
  Carvana dataset for the U-Net experiment.
---

# Carvana Dataset for the [U-Net](index.html) [experiment](experiment.html)

You can find the download instructions
[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).

Save the training images in the `carvana/train` folder and the masks in the
`carvana/train_masks` folder, both inside the labml data directory.
"""
15
16
from pathlib import Path
17
18
import torchvision.transforms.functional
19
from PIL import Image
20
21
import torch.utils.data
22
from labml import lab
23
24
25
class CarvanaDataset(torch.utils.data.Dataset):
    """
    ## Carvana Dataset

    Pairs every image in the image folder with the mask that has the same id.
    Mask file stems are expected to be the image id followed by `_mask`.
    """

    # Suffix of mask file stems; removing it gives the id of the matching image
    _MASK_SUFFIX = '_mask'

    def __init__(self, image_path: Path, mask_path: Path, image_size: int = 572):
        """
        :param image_path: is the path to the images
        :param mask_path: is the path to the masks
        :param image_size: is the size the smaller edge of images and masks is resized to
        """
        # Get a dictionary of images by id (sub-directories are not samples)
        self.images = {p.stem: p for p in image_path.iterdir() if p.is_file()}
        # Get a dictionary of masks keyed by the id of the image they belong to
        self.masks = {self._mask_id(p.stem): p for p in mask_path.iterdir() if p.is_file()}

        # Image ids list.
        # Only images that have a mask are kept, so `__getitem__` never fails on a
        # missing mask (e.g. stray files like `.DS_Store`). The ids are sorted so an
        # index always maps to the same image, independent of directory listing order.
        self.ids = self._paired_ids(self.images, self.masks)

        # Transformations
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    @classmethod
    def _mask_id(cls, stem: str) -> str:
        """Return the image id for a mask file stem by removing the `_mask` suffix; stems without it are kept as-is."""
        if stem.endswith(cls._MASK_SUFFIX):
            return stem[:-len(cls._MASK_SUFFIX)]
        return stem

    @staticmethod
    def _paired_ids(images: dict, masks: dict) -> list:
        """Return the sorted ids that are present in both `images` and `masks`."""
        return sorted(id_ for id_ in images if id_ in masks)

    @staticmethod
    def _normalize_mask(mask: torch.Tensor) -> torch.Tensor:
        """
        Scale `mask` so that its maximum value is $1$.

        An all-zero mask is returned unchanged instead of being divided by zero,
        which would fill it with NaNs.
        """
        max_value = mask.max()
        if max_value <= 0:
            return mask
        return mask / max_value

    def __getitem__(self, idx: int):
        """
        #### Get an image and its mask.

        :param idx: is index of the image
        """

        # Get image id
        id_ = self.ids[idx]
        # Load image, transform it and convert it to a PyTorch tensor.
        # `Image.open` is lazy and keeps the file open, so close it with `with`.
        with Image.open(self.images[id_]) as img:
            image = self.transforms(img)
        # Load mask, transform it and convert it to a PyTorch tensor
        with Image.open(self.masks[id_]) as img:
            mask = self.transforms(img)

        # The mask values were not $1$, so we scale it appropriately.
        mask = self._normalize_mask(mask)

        # Return the image and the mask
        return image, mask

    def __len__(self):
        """
        #### Size of the dataset
        """
        return len(self.ids)
78
79
80
# Smoke test: build the dataset from the labml data directory
if __name__ == '__main__':
    # Both the images and the masks live under `<data path>/carvana`
    carvana_dir = lab.get_data_path() / 'carvana'
    ds = CarvanaDataset(carvana_dir / 'train', carvana_dir / 'train_masks')