# Source: probml/pyprobml — deprecated/gan/data.py
# (page chrome from the CoCalc share viewer removed)
from functools import partial
import pandas as pd
import os
import PIL
import glob

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils, io
from torchvision.datasets.utils import verify_str_arg

import pytorch_lightning as pl


class CelebADataset(Dataset):
    """CelebA dataset backed by the Kaggle ``jessicali9530/celeba-dataset`` mirror.

    Loads the annotation CSVs from ``root`` and serves ``(image, target)``
    pairs, where the image is read from
    ``root/img_align_celeba/img_align_celeba/<filename>``.

    Args:
        root: Directory holding the annotation CSVs and the image folder.
        split: One of ``"train"``, ``"valid"``, ``"test"``, ``"all"``.
        target_type: A string or list of strings among ``"attr"``, ``"bbox"``,
            ``"landmarks"`` selecting which annotations to return.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the target.
        download: If True, fetch the dataset from Kaggle first (requires a
            Kaggle API token in ``$HOME/.kaggle``).

    Raises:
        RuntimeError: If ``target_transform`` is given while ``target_type``
            is empty.
        ValueError: From ``verify_str_arg`` when ``split`` is not recognized.
    """

    def __init__(self, root, split="train", target_type="attr", transform=None, target_transform=None, download=False):
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        # Normalize target_type to a list so __getitem__ can iterate it.
        self.target_type = target_type if isinstance(target_type, list) else [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download:
            self.download_from_kaggle()

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,  # None -> keep every partition
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        fn = partial(os.path.join, self.root)
        splits = pd.read_csv(fn("list_eval_partition.csv"), header=0, index_col=0)
        # identity_CelebA is not available in the Kaggle mirror.
        # identity = pd.read_csv(fn("identity_CelebA.csv"), delim_whitespace=True, header=None, index_col=0)
        bbox = pd.read_csv(fn("list_bbox_celeba.csv"), header=0, index_col=0)
        landmarks_align = pd.read_csv(fn("list_landmarks_align_celeba.csv"), header=0, index_col=0)
        attr = pd.read_csv(fn("list_attr_celeba.csv"), header=0, index_col=0)

        # "all" keeps every row; otherwise keep rows whose partition matches.
        mask = slice(None) if split_ is None else (splits["partition"] == split_)

        self.filename = splits[mask].index.values
        # self.identity = torch.as_tensor(identity[mask].values)
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map labels from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

    def download_from_kaggle(self):
        """Download and extract the CelebA Kaggle dataset into ``self.root``.

        Skips the download only when *every* annotation CSV is already
        present (the original check overwrote the flag per file, so only the
        last file in the list was actually tested).
        """
        # Annotation files; downloaded individually after the image archive.
        label_files = [
            "list_attr_celeba.csv",
            "list_bbox_celeba.csv",
            "list_eval_partition.csv",
            "list_landmarks_align_celeba.csv",
        ]

        # BUGFIX: require ALL label files, not just the last one checked.
        files_exist = all(os.path.isfile(os.path.join(self.root, f)) for f in label_files)

        if files_exist:
            print("Files exist already")
            return

        print("Downloading dataset. Please wait while the download and extraction processes complete")
        # Download files from Kaggle using its API as per
        # https://stackoverflow.com/questions/55934733/documentation-for-kaggle-api-within-python

        # Kaggle authentication.
        # Remember to place the API token from Kaggle in $HOME/.kaggle
        from kaggle.api.kaggle_api_extended import KaggleApi

        api = KaggleApi()
        api.authenticate()

        # Download all files of the dataset (images archive, unzipped in place).
        # Signature: dataset_download_files(dataset, path=None, force=False, quiet=True, unzip=False)
        api.dataset_download_files(
            dataset="jessicali9530/celeba-dataset", path=self.root, unzip=True, force=False, quiet=False
        )

        # Download the label files individually.
        # Signature: dataset_download_file(dataset, file_name, path=None, force=False, quiet=True)
        for label_file in label_files:
            api.dataset_download_file(
                dataset="jessicali9530/celeba-dataset",
                file_name=label_file,
                path=self.root,
                force=False,
                quiet=False,
            )

        # Clear any remaining *.csv.zip files left over from the per-file downloads.
        for f in glob.glob(os.path.join(self.root, "*.csv.zip")):
            os.remove(f)

        print("Done!")

    def __getitem__(self, index: int):
        """Return the ``(image, target)`` pair at ``index``.

        ``target`` is a single tensor when one target type was requested, a
        tuple of tensors for several, or ``None`` when ``target_type`` is
        empty.
        """
        X = PIL.Image.open(os.path.join(self.root, "img_align_celeba", "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            # elif t == "identity":
            #     target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                raise ValueError(f"Target type {t} is not recognized")

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return len(self.attr)


class CelebADataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/valid/test CelebA splits.

    Args:
        data_dir: Root directory passed to ``CelebADataset``.
        target_type: Target selection forwarded to every split.
        train_transform: Image transform for the training split.
        val_transform: Image transform for the validation and test splits.
        target_transform: Target transform forwarded to every split.
        download: If True, the training split triggers the Kaggle download.
        batch_size: Batch size shared by all dataloaders.
        num_workers: Worker count shared by all dataloaders.
    """

    def __init__(
        self,
        data_dir,
        target_type="attr",
        train_transform=None,
        val_transform=None,
        target_transform=None,
        download=False,
        batch_size=32,
        num_workers=8,
    ):
        super().__init__()

        self.data_dir = data_dir
        self.target_type = target_type
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.target_transform = target_transform
        self.download = download
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_dataset(self, split, transform, download):
        """Build one ``CelebADataset`` split with the module's shared settings."""
        return CelebADataset(
            root=self.data_dir,
            split=split,
            target_type=self.target_type,
            download=download,
            transform=transform,
            target_transform=self.target_transform,
        )

    def setup(self, stage=None):
        """Instantiate the three splits; only the train split may download."""
        self.celebA_trainset = self._make_dataset("train", self.train_transform, self.download)
        self.celebA_valset = self._make_dataset("valid", self.val_transform, False)
        self.celebA_testset = self._make_dataset("test", self.val_transform, False)

    def train_dataloader(self):
        """Shuffled training loader that drops the last partial batch."""
        return DataLoader(
            self.celebA_trainset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Deterministic loader over the validation split."""
        return DataLoader(
            self.celebA_valset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Deterministic loader over the test split."""
        return DataLoader(
            self.celebA_testset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )