GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/data.py
"""
This data utilities file for CelebA is from https://github.com/sayantanauddy/vae_lightning
"""

from functools import partial
import pandas as pd
import os
import PIL.Image
import glob

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils, io
from torchvision.datasets.utils import verify_str_arg

import pytorch_lightning as pl


class CelebADataset(Dataset):
    """CelebA Dataset class."""

    def __init__(self, root, split="train", target_type="attr", transform=None, target_transform=None, download=False):
        """
        Args:
            root: directory containing the CelebA images and annotation CSVs.
            split: one of "train", "valid", "test", or "all".
            target_type: "attr", "bbox", "landmarks", or a list of these.
            transform: optional transform applied to each PIL image.
            target_transform: optional transform applied to each target.
            download: if True, fetch the dataset from Kaggle first.
        """

        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download:
            self.download_from_kaggle()

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }

        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        fn = partial(os.path.join, self.root)
        splits = pd.read_csv(fn("list_eval_partition.csv"), delim_whitespace=False, header=0, index_col=0)
        # This file is not available in Kaggle
        # identity = pd.read_csv(fn("identity_CelebA.csv"), delim_whitespace=True, header=None, index_col=0)
        bbox = pd.read_csv(fn("list_bbox_celeba.csv"), delim_whitespace=False, header=0, index_col=0)
        landmarks_align = pd.read_csv(
            fn("list_landmarks_align_celeba.csv"), delim_whitespace=False, header=0, index_col=0
        )
        attr = pd.read_csv(fn("list_attr_celeba.csv"), delim_whitespace=False, header=0, index_col=0)

        mask = slice(None) if split_ is None else (splits["partition"] == split_)

        self.filename = splits[mask].index.values
        # self.identity = torch.as_tensor(identity[mask].values)
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

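    # A sketch (not part of the original file) of the layout the annotation CSVs
    # are expected to have: the code above relies on the first column holding the
    # image file name (used as the index) and on a "partition" column in
    # list_eval_partition.csv:
    #
    #   image_id,partition
    #   000001.jpg,0
    #   000002.jpg,0
    #   ...
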
    def download_from_kaggle(self):

        # Annotation files will be downloaded at the end
        label_files = [
            "list_attr_celeba.csv",
            "list_bbox_celeba.csv",
            "list_eval_partition.csv",
            "list_landmarks_align_celeba.csv",
        ]

        # Check whether all annotation files have been downloaded already
        files_exist = all(os.path.isfile(os.path.join(self.root, label_file)) for label_file in label_files)

        if files_exist:
            print("Files exist already")
        else:
            print("Downloading dataset. Please wait while the download and extraction processes complete")
            # Download files from Kaggle using its API as per
            # https://stackoverflow.com/questions/55934733/documentation-for-kaggle-api-within-python

            # Kaggle authentication
            # Remember to place the API token from Kaggle in $HOME/.kaggle
            from kaggle.api.kaggle_api_extended import KaggleApi

            api = KaggleApi()
            api.authenticate()

            # Download all files of a dataset
            # Signature: dataset_download_files(dataset, path=None, force=False, quiet=True, unzip=False)
            api.dataset_download_files(
                dataset="jessicali9530/celeba-dataset", path=self.root, unzip=True, force=False, quiet=False
            )

            # Download the label files
            # Signature: dataset_download_file(dataset, file_name, path=None, force=False, quiet=True)
            for label_file in label_files:
                api.dataset_download_file(
                    dataset="jessicali9530/celeba-dataset",
                    file_name=label_file,
                    path=self.root,
                    force=False,
                    quiet=False,
                )

            # Clear any remaining *.csv.zip files
            files_to_delete = glob.glob(os.path.join(self.root, "*.csv.zip"))
            for f in files_to_delete:
                os.remove(f)

            print("Done!")

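    # A minimal sketch (not part of the original file) of the credential setup
    # the Kaggle API expects before download_from_kaggle() can authenticate;
    # the kaggle.json token file and its location are Kaggle API conventions:
    #
    #   mkdir -p ~/.kaggle
    #   cp /path/to/downloaded/kaggle.json ~/.kaggle/kaggle.json
    #   chmod 600 ~/.kaggle/kaggle.json
    #
    # KaggleApi.authenticate() can also read the KAGGLE_USERNAME and KAGGLE_KEY
    # environment variables instead of the token file.
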
    def __getitem__(self, index: int):
        X = PIL.Image.open(os.path.join(self.root, "img_align_celeba", "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            # elif t == "identity":
            #     target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                raise ValueError(f"Target type {t} is not recognized")

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        return len(self.attr)

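# Example usage of CelebADataset (a minimal sketch, not part of the original
# file; the root directory name is an assumption, and 178x218 is the standard
# aligned-CelebA image size):
#
#   dataset = CelebADataset(root="celeba_data", split="train",
#                           transform=transforms.ToTensor(), download=True)
#   image, attrs = dataset[0]  # image: float tensor of shape (3, 218, 178)
#   print(len(dataset), attrs.shape)  # attrs holds the 40 binary attribute labels
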
class CelebADataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        target_type="attr",
        train_transform=None,
        val_transform=None,
        target_transform=None,
        download=False,
        batch_size=32,
        num_workers=8,
    ):

        super().__init__()

        self.data_dir = data_dir
        self.target_type = target_type
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.target_transform = target_transform
        self.download = download

        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):

        # Training dataset
        self.celebA_trainset = CelebADataset(
            root=self.data_dir,
            split="train",
            target_type=self.target_type,
            download=self.download,
            transform=self.train_transform,
            target_transform=self.target_transform,
        )

        # Validation dataset
        self.celebA_valset = CelebADataset(
            root=self.data_dir,
            split="valid",
            target_type=self.target_type,
            download=False,
            transform=self.val_transform,
            target_transform=self.target_transform,
        )

        # Test dataset
        self.celebA_testset = CelebADataset(
            root=self.data_dir,
            split="test",
            target_type=self.target_type,
            download=False,
            transform=self.val_transform,
            target_transform=self.target_transform,
        )

    def train_dataloader(self):
        return DataLoader(
            self.celebA_trainset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=self.num_workers
        )

    def val_dataloader(self):
        return DataLoader(
            self.celebA_valset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return DataLoader(
            self.celebA_testset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )
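

# Example usage of CelebADataModule (a minimal sketch, not part of the original
# file; the data directory, crop size, batch size, and worker count below are
# assumptions):

if __name__ == "__main__":
    transform = transforms.Compose(
        [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ]
    )
    dm = CelebADataModule(
        data_dir="celeba_data",
        train_transform=transform,
        val_transform=transform,
        download=True,
        batch_size=32,
        num_workers=4,
    )
    dm.setup()
    images, attrs = next(iter(dm.train_dataloader()))
    print(images.shape, attrs.shape)  # expected: [32, 3, 64, 64] and [32, 40]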