CelebA dataset with PyTorch Lightning
Install lightning
Getting the data using torchvision.datasets
It is difficult to use torchvision.datasets.CelebA, as illustrated below. The data is hosted on the authors' personal Google Drive account, which limits how many times it can be downloaded per day, so automated downloads often fail once that quota is used up. For details, see this issue and this issue. As a result, you have to download the data manually.
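The files needed for the CelebA dataset are defined in the file_list attribute of torchvision's CelebA class. They are img_align_celeba.zip (the aligned face images) plus the annotation files list_attr_celeba.txt, identity_CelebA.txt, list_bbox_celeba.txt, list_landmarks_align_celeba.txt and list_eval_partition.txt.
The zip file contains 202,600 .jpg files, and the archive is 1.34 GB.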
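Before constructing the dataset, it can help to check which of these files are actually present. The sketch below assumes the layout torchvision uses: everything sits under <root>/celeba/, and once the zip has been extracted, the img_align_celeba/ folder is required in place of the archive. missing_celeba_files is a hypothetical helper, not part of torchvision. Unlike torchvision, it does not verify MD5 hashes.

```python
import os

# Filenames listed in torchvision's CelebA.file_list (the image zip plus annotations).
CELEBA_FILES = [
    "img_align_celeba.zip",
    "list_attr_celeba.txt",
    "identity_CelebA.txt",
    "list_bbox_celeba.txt",
    "list_landmarks_align_celeba.txt",
    "list_eval_partition.txt",
]


def missing_celeba_files(root, base_folder="celeba"):
    """Return the expected CelebA entries that are absent under <root>/<base_folder>.

    The zip archive itself is not required; instead the extracted
    img_align_celeba/ directory must exist. File contents are not checked.
    """
    folder = os.path.join(root, base_folder)
    missing = []
    for name in CELEBA_FILES:
        if name.endswith(".zip"):
            continue  # only the extracted images are needed
        if not os.path.isfile(os.path.join(folder, name)):
            missing.append(name)
    if not os.path.isdir(os.path.join(folder, "img_align_celeba")):
        missing.append("img_align_celeba/")
    return missing
```

For example, missing_celeba_files("data") lists what CelebA(root="data", ...) would still need. An empty list means the expected layout is in place, although the file contents have not been checked.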
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-29-3796468581c3> in <module>()
1 from torchvision.datasets import CelebA
----> 2 dataset = CelebA(root = 'data', split = "train", download=True)
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/celeba.py in __init__(self, root, split, target_type, transform, target_transform, download)
78
79 if download:
---> 80 self.download()
81
82 if not self._check_integrity():
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/celeba.py in download(self)
150
151 for (file_id, md5, filename) in self.file_list:
--> 152 download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
153
154 with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/utils.py in download_file_from_google_drive(file_id, root, filename, md5)
230 f"and can only be overcome by trying again later."
231 )
--> 232 raise RuntimeError(msg)
233
234 _save_response_content(response, fpath)
RuntimeError: The daily quota of the file img_align_celeba.zip is exceeded and it can't be downloaded. This is a limitation of Google Drive and can only be overcome by trying again later.
There are two main workarounds: download the data from Kaggle, or make a symbolic link to the dataset in your own Google Drive and copy the files locally to Colab.
Kaggle method
Get API key from Kaggle
Follow these instructions to get a kaggle.json API key file, then upload it to Colab.
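In Colab, the uploaded key then has to be moved to ~/.kaggle/kaggle.json, which is where the Kaggle command-line tool looks for it. Its permissions should also be restricted to the owner, because the tool warns otherwise. Below is a minimal sketch in Python. install_kaggle_key is a hypothetical helper name, and the uploaded file is assumed to be kaggle.json in the working directory.

```python
import os
import shutil


def install_kaggle_key(uploaded_json, home=None):
    """Copy an uploaded kaggle.json to <home>/.kaggle/ and make it owner-only.

    home defaults to the current user's home directory.
    Returns the destination path.
    """
    home = home or os.path.expanduser("~")
    kaggle_dir = os.path.join(home, ".kaggle")
    os.makedirs(kaggle_dir, exist_ok=True)
    dest = os.path.join(kaggle_dir, "kaggle.json")
    shutil.copyfile(uploaded_json, dest)
    os.chmod(dest, 0o600)  # owner read/write only, as the Kaggle tool expects
    return dest
```

Calling install_kaggle_key("kaggle.json") once per Colab session should be enough. You can get the same effect with mkdir, cp and chmod 600 in a shell cell.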
PyTorch dataset and Lightning datamodule
This replaces torchvision.datasets.CelebA with a version that downloads the data from Kaggle instead of Google Drive.
Code is from https://github.com/sayantanauddy/vae_lightning/blob/main/data.py
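The core job of such a replacement is mapping a split name to a list of image files. The sketch below is not the code from the linked repository. It assumes the whitespace-separated list_eval_partition.txt format that torchvision reads: one "filename label" pair per line, with 0 = train, 1 = valid and 2 = test. The Kaggle copy of the dataset may store its annotations in a different format, in which case read_partition would need adjusting. read_partition and CelebAImages are hypothetical names.

```python
import os

# Split labels used in list_eval_partition.txt (same mapping as torchvision).
SPLIT_MAP = {"train": 0, "valid": 1, "test": 2}

try:  # subclass torch's Dataset when torch is installed, plain object otherwise
    from torch.utils.data import Dataset
except ImportError:
    Dataset = object


def read_partition(path, split):
    """Return image filenames for split ("train", "valid", "test" or "all")."""
    names = []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 2:
                continue  # skip blank or malformed lines
            filename, label = parts
            if split == "all" or int(label) == SPLIT_MAP[split]:
                names.append(filename)
    return names


class CelebAImages(Dataset):
    """Map-style dataset over one CelebA split (hypothetical sketch)."""

    def __init__(self, img_dir, partition_file, split="train", transform=None, loader=None):
        self.img_dir = img_dir
        self.filenames = read_partition(partition_file, split)
        self.transform = transform
        self.loader = loader or self._pil_loader

    @staticmethod
    def _pil_loader(path):
        from PIL import Image  # imported lazily so the sketch loads without Pillow
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img = self.loader(os.path.join(self.img_dir, self.filenames[idx]))
        return self.transform(img) if self.transform else img
```

Passing a torchvision transform such as transforms.ToTensor() as transform yields tensors that a DataLoader can batch. The loader argument exists mainly so the class can be exercised without real images.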
Vanilla datamodule
Cropping datamodule
https://github.com/AntixK/PyTorch-VAE/blob/master/experiment.py#L135
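A cropping datamodule typically center-crops the face before resizing it. The pure-Python sketch below computes the crop box under two assumptions: the aligned CelebA images are 178 pixels wide and 218 tall, and the crop is a 148-pixel square (an illustrative value, not one stated in this notebook). The offsets are rounded the same way torchvision's center_crop rounds them. The second function shows the common rescaling of pixel values from [0, 1] to [-1, 1] via 2x - 1.

```python
def center_crop_box(width, height, crop):
    """Return (left, top, right, bottom) of a centered crop x crop square.

    Offsets use the same round((size - crop) / 2) rule as torchvision's
    center_crop. Assumes the crop fits inside the image (no padding).
    """
    if crop > width or crop > height:
        raise ValueError("crop larger than image; torchvision would pad instead")
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def to_symmetric_range(x):
    """Map a pixel value in [0, 1] to [-1, 1]."""
    return 2.0 * x - 1.0


# Aligned CelebA images are 178 wide and 218 tall; 148 is an assumed crop size.
print(center_crop_box(178, 218, 148))  # (15, 35, 163, 183)
```

In torchvision terms, this corresponds to transforms.CenterCrop(148) followed by transforms.Resize and transforms.ToTensor. The rescaling is applied afterwards, for example through transforms.Lambda.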
GDrive method
This currently does not work: the data is copied to Colab, but torchvision.datasets.CelebA cannot read it.
Mount your google drive
Make a symbolic link
Make a symbolic link to the dataset by opening this link and clicking "Add shortcut to Drive".
Copy the files from Google Drive to the Colab machine
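Below is a minimal sketch of the copy step. It assumes the Drive shortcut exposes img_align_celeba.zip and the five annotation .txt files somewhere below one folder. No particular subfolder structure is assumed, so the function searches recursively. The files are then arranged in the layout torchvision's CelebA checks for: <root>/celeba/ holding the annotation files and the extracted img_align_celeba/ folder. stage_celeba is a hypothetical helper. It also assumes the zip has a top-level img_align_celeba/ folder, which torchvision's own download code relies on as well.

```python
import os
import shutil
import zipfile

ANNOTATION_FILES = [
    "list_attr_celeba.txt",
    "identity_CelebA.txt",
    "list_bbox_celeba.txt",
    "list_landmarks_align_celeba.txt",
    "list_eval_partition.txt",
]
ZIP_NAME = "img_align_celeba.zip"


def stage_celeba(src_dir, root, base_folder="celeba"):
    """Copy CelebA files found anywhere under src_dir into <root>/<base_folder>/.

    Annotation files are copied as-is and the image zip is extracted, so the
    images end up on the local disk. Returns the destination folder.
    """
    wanted = set(ANNOTATION_FILES) | {ZIP_NAME}
    found = {}
    for dirpath, _, filenames in os.walk(src_dir):
        for name in filenames:
            if name in wanted and name not in found:
                found[name] = os.path.join(dirpath, name)
    missing = wanted - set(found)
    if missing:
        raise FileNotFoundError(f"not found under {src_dir}: {sorted(missing)}")

    dest = os.path.join(root, base_folder)
    os.makedirs(dest, exist_ok=True)
    for name in ANNOTATION_FILES:
        shutil.copyfile(found[name], os.path.join(dest, name))
    # Assumes the archive holds a top-level img_align_celeba/ folder, so
    # extracting into dest yields dest/img_align_celeba/*.jpg.
    with zipfile.ZipFile(found[ZIP_NAME]) as zf:
        zf.extractall(dest)
    return dest
```

For example, you could first mount Drive with drive.mount('/content/drive') from google.colab. Then call stage_celeba('/content/drive/MyDrive/CelebA', 'data'), where the Drive path is hypothetical, followed by CelebA(root='data', split='test', download=False). Note that root is the parent of the celeba folder, not the folder itself.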
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-74-6688193a0934> in <module>()
1 from torchvision.datasets import CelebA
2 from torch.utils.data import DataLoader
----> 3 ds = CelebA(root = 'celeba', split = "test", download=True)
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/celeba.py in __init__(self, root, split, target_type, transform, target_transform, download)
78
79 if download:
---> 80 self.download()
81
82 if not self._check_integrity():
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/celeba.py in download(self)
150
151 for (file_id, md5, filename) in self.file_list:
--> 152 download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
153
154 with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/utils.py in download_file_from_google_drive(file_id, root, filename, md5)
230 f"and can only be overcome by trying again later."
231 )
--> 232 raise RuntimeError(msg)
233
234 _save_response_content(response, fpath)
RuntimeError: The daily quota of the file img_align_celeba.zip is exceeded and it can't be downloaded. This is a limitation of Google Drive and can only be overcome by trying again later.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-75-444dc1ab345c> in <module>()
----> 1 ds = CelebA(root = 'celeba', split = "test", download=False) #fails
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/celeba.py in __init__(self, root, split, target_type, transform, target_transform, download)
81
82 if not self._check_integrity():
---> 83 raise RuntimeError('Dataset not found or corrupted.' +
84 ' You can use download=True to download it')
85
RuntimeError: Dataset not found or corrupted. You can use download=True to download it
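In torchvision releases from around the time of these tracebacks, this message comes from _check_integrity. That method looks inside <root>/celeba/ (the class's base_folder), checks the MD5 hash of each annotation file in CelebA.file_list, and requires the extracted img_align_celeba/ folder to exist. One plausible cause of the failure above is therefore the root argument: with root='celeba', torchvision looks in celeba/celeba/. Another is annotation files whose contents do not match the expected hashes. The hypothetical helpers below compute MD5 digests in chunks, similar to torchvision's own check, so the copied files can be compared against the hashes stored in the class.

```python
import hashlib
import os


def file_md5(path, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of a file, reading it chunk by chunk."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def mismatched_files(folder, expected):
    """Return names in expected ({filename: md5}) that are missing or differ."""
    bad = []
    for name, md5 in expected.items():
        path = os.path.join(folder, name)
        if not os.path.isfile(path) or file_md5(path) != md5:
            bad.append(name)
    return bad
```

For example, first build expected = {fn: md5 for (_, md5, fn) in CelebA.file_list if not fn.endswith('.zip')}, using the (file_id, md5, filename) tuple order visible in the traceback. Then mismatched_files('celeba/celeba', expected) lists the files that would fail the check.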