Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/data_util.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/data_util.py
6
7
import os
8
import random
9
10
from torch.utils.data import Dataset
11
from torchvision.datasets import CIFAR10, CIFAR100
12
from torchvision.datasets import ImageFolder
13
from torchvision.transforms import InterpolationMode
14
from scipy import io
15
from PIL import ImageOps, Image
16
import torch
17
import torchvision.transforms as transforms
18
import h5py as h5
19
import numpy as np
20
21
22
# Lookup table from the resizer names accepted by Dataset_ (its `resizer`
# argument) to the matching torchvision interpolation modes.
resizer_collection = {
    "nearest": InterpolationMode.NEAREST,
    "box": InterpolationMode.BOX,
    "bilinear": InterpolationMode.BILINEAR,
    "hamming": InterpolationMode.HAMMING,
    "bicubic": InterpolationMode.BICUBIC,
    "lanczos": InterpolationMode.LANCZOS,
}
28
29
class RandomCropLongEdge(object):
    """Randomly crop an image to a square whose side is its shorter edge.

    this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
    MIT License
    Copyright (c) 2019 Andy Brock
    """

    def __call__(self, img):
        """Return a square crop of ``img`` taken at a random offset along its long edge.

        Args:
            img: image exposing ``.size`` as a 2-tuple
                (assumed to be a PIL image, i.e. ``(width, height)`` -- TODO confirm).

        Returns:
            The result of ``transforms.functional.crop`` with both sides equal
            to ``min(img.size)``.
        """
        width, height = img.size
        side = min(width, height)
        # Only step along an edge if it is the long edge. ``np.random.randint``
        # excludes ``high``, so ``+ 1`` is required for the last valid offset
        # (edge - side) to be selectable.
        left = 0 if side == width else np.random.randint(low=0, high=width - side + 1)
        top = 0 if side == height else np.random.randint(low=0, high=height - side + 1)
        return transforms.functional.crop(img, top, left, side, side)

    def __repr__(self):
        """Return the class name, used when the transform pipeline is printed."""
        return self.__class__.__name__
44
45
46
class CenterCropLongEdge:
    """Center-crop an image to a square whose side is its shorter edge.

    this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
    MIT License
    Copyright (c) 2019 Andy Brock
    """

    def __call__(self, img):
        """Return the central ``min(img.size)``-sided square of ``img``."""
        short_edge = min(img.size)
        return transforms.functional.center_crop(img, short_edge)

    def __repr__(self):
        """Return the class name, used when the transform pipeline is printed."""
        return type(self).__name__
57
58
59
class Dataset_(Dataset):
    """Image dataset backed by CIFAR10/CIFAR100, an ImageFolder tree, or an HDF5 file.

    Every item is returned as ``(transformed_image, int_label)``.

    Args:
        data_name: "CIFAR10" or "CIFAR100" select the torchvision datasets; any
            other value loads an ``ImageFolder`` from ``data_dir``.
        data_dir: root directory of the dataset.
        train: truthy for the training split. For ``ImageFolder`` data this
            selects the "train" sub-directory, otherwise "valid".
        crop_long_edge: center-crop images to a square before resizing
            (not applied when ``hdf5_path`` is given).
        resize_size: size handed to ``transforms.Resize``; ``None`` disables
            resizing (not applied when ``hdf5_path`` is given).
        resizer: key of ``resizer_collection`` naming the interpolation mode,
            or "wo_resize" to skip resizing.
        random_flip: append ``RandomHorizontalFlip`` to the transform.
        normalize: if true, apply ``ToTensor`` followed by ``Normalize`` with
            mean 0.5 and std 0.5 on three channels; otherwise apply
            ``PILToTensor`` only.
        hdf5_path: optional HDF5 file with "imgs" and "labels" datasets; when
            given it is used instead of the on-disk dataset.
        load_data_in_memory: read the whole HDF5 file into memory up front
            instead of re-opening it for every item.
    """

    def __init__(self,
                 data_name,
                 data_dir,
                 train,
                 crop_long_edge=False,
                 resize_size=None,
                 resizer="lanczos",
                 random_flip=False,
                 normalize=True,
                 hdf5_path=None,
                 load_data_in_memory=False):
        super().__init__()
        self.data_name = data_name
        self.data_dir = data_dir
        self.train = train
        self.random_flip = random_flip
        self.normalize = normalize
        self.hdf5_path = hdf5_path
        self.load_data_in_memory = load_data_in_memory
        self.trsf_list = []

        if self.hdf5_path is None:
            if crop_long_edge:
                self.trsf_list += [CenterCropLongEdge()]
            if resize_size is not None and resizer != "wo_resize":
                self.trsf_list += [transforms.Resize(resize_size, interpolation=resizer_collection[resizer])]
        else:
            # HDF5 items are only converted to PIL images here; no crop/resize
            # is applied (presumably they were preprocessed when the file was
            # written -- TODO confirm).
            self.trsf_list += [transforms.ToPILImage()]

        if self.random_flip:
            self.trsf_list += [transforms.RandomHorizontalFlip()]

        if self.normalize:
            self.trsf_list += [transforms.ToTensor()]
            self.trsf_list += [transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        else:
            self.trsf_list += [transforms.PILToTensor()]

        self.trsf = transforms.Compose(self.trsf_list)

        self.load_dataset()

    def load_dataset(self):
        """Populate ``self.data`` (and, for HDF5, ``self.labels``/``self.num_dataset``).

        With ``hdf5_path`` set, only the item count is recorded unless
        ``load_data_in_memory`` is true, in which case images and labels are
        read fully into memory. Otherwise the CIFAR datasets (downloaded if
        necessary) or an ``ImageFolder`` are instantiated.
        """
        if self.hdf5_path is not None:
            with h5.File(self.hdf5_path, "r") as f:
                data, labels = f["imgs"], f["labels"]
                self.num_dataset = data.shape[0]
                if self.load_data_in_memory:
                    print(f"Load {self.hdf5_path} into memory.")
                    # Slicing copies the contents out before the file is closed.
                    self.data = data[:]
                    self.labels = labels[:]
            return

        if self.data_name == "CIFAR10":
            self.data = CIFAR10(root=self.data_dir, train=self.train, download=True)
        elif self.data_name == "CIFAR100":
            self.data = CIFAR100(root=self.data_dir, train=self.train, download=True)
        else:
            # Use truthiness so `train` is interpreted the same way here as in
            # the CIFAR branches, which forward it unchanged.
            mode = "train" if self.train else "valid"
            root = os.path.join(self.data_dir, mode)
            self.data = ImageFolder(root=root)

    def _get_hdf5(self, index):
        """Read one ``(image, label)`` pair from the HDF5 file.

        The file is opened and closed on every call.
        NOTE(review): presumably so no open handle is shared between
        DataLoader workers -- confirm.
        """
        with h5.File(self.hdf5_path, "r") as f:
            return f["imgs"][index], f["labels"][index]

    def __len__(self):
        """Return the number of items in the dataset."""
        if self.hdf5_path is None:
            return len(self.data)
        return self.num_dataset

    def __getitem__(self, index):
        """Return ``(self.trsf(image), int(label))`` for the item at ``index``."""
        if self.hdf5_path is None:
            img, label = self.data[index]
        elif self.load_data_in_memory:
            img, label = self.data[index], self.labels[index]
        else:
            img, label = self._get_hdf5(index)
        return self.trsf(img), int(label)
143
144