Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/sample.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/utils/sample.py
6
7
import random
8
9
from torch.nn import DataParallel
10
from torch.nn.parallel import DistributedDataParallel
11
from torch import autograd
12
from numpy import linalg
13
from math import sin, cos, sqrt
14
from scipy.stats import truncnorm
15
import torch
16
import torch.nn.functional as F
17
import torch.distributions.multivariate_normal as MN
18
import numpy as np
19
20
import utils.ops as ops
21
import utils.losses as losses
22
try:
    import utils.misc as misc
except AttributeError:
    # NOTE(review): AttributeError (not ImportError) is deliberately caught —
    # presumably to tolerate a partially-initialized ``utils`` package during a
    # circular import, where binding the ``misc`` attribute can fail with
    # AttributeError. Confirm before narrowing or removing this guard.
    pass
26
27
28
def truncated_normal(size, threshold=1.):
    """Draw samples from N(0, 1) truncated to [-threshold, threshold].

    Args:
        size: output shape accepted by ``scipy.stats.truncnorm.rvs``.
        threshold: symmetric truncation bound (must be positive).

    Returns:
        numpy array of shape ``size`` with every entry in
        [-threshold, threshold].
    """
    return truncnorm.rvs(-threshold, threshold, size=size)
31
32
33
def sample_normal(batch_size, z_dim, truncation_factor, device):
    """Sample a (batch_size, z_dim) latent batch from a standard normal.

    Args:
        batch_size: number of latent vectors to draw.
        z_dim: dimensionality of each latent vector.
        truncation_factor: -1.0 for an untruncated N(0, I) draw; any positive
            value truncates each coordinate to
            [-truncation_factor, truncation_factor] via scipy's truncnorm.
        device: torch device for the returned tensor.

    Returns:
        torch.FloatTensor of shape (batch_size, z_dim) on ``device``.

    Raises:
        ValueError: if ``truncation_factor`` is neither -1.0 nor positive.
    """
    if truncation_factor == -1.0:
        latents = torch.randn(batch_size, z_dim, device=device)
    elif truncation_factor > 0:
        # truncated_normal samples on CPU via scipy, so move to the target device.
        latents = torch.FloatTensor(truncated_normal([batch_size, z_dim], truncation_factor)).to(device)
    else:
        # Fixed message: the parameter is ``truncation_factor`` (old message said
        # "truncated_factor"), and -1.0 is a legal sentinel for "no truncation".
        raise ValueError("truncation_factor must be positive or -1 (no truncation).")
    return latents
41
42
43
def sample_y(y_sampler, batch_size, num_classes, device):
    """Draw fake class labels according to the requested sampling scheme.

    Args:
        y_sampler: one of "totally_random" (uniform random labels),
            "acending_some" (batch_size // 8 randomly chosen classes, each
            repeated 8 times), "acending_all" (every class repeated 8 times;
            ``batch_size`` is ignored), an int (that fixed class for the whole
            batch), or anything else (returns None).
        batch_size: requested number of labels (see ``y_sampler`` caveats).
        num_classes: number of classes to sample from.
        device: torch device for the returned tensor.

    Returns:
        LongTensor of labels, or None for an unrecognized sampler.
    """
    if y_sampler == "totally_random":
        return torch.randint(low=0, high=num_classes, size=(batch_size, ), dtype=torch.long, device=device)

    if y_sampler == "acending_some":
        assert batch_size % 8 == 0, "The size of batches should be a multiple of 8."
        # Pick batch_size // 8 distinct classes at random.
        chosen = np.random.permutation(num_classes)[:batch_size // 8]
    elif y_sampler == "acending_all":
        # Ignores the requested batch_size: always num_classes * 8 labels.
        chosen = list(range(num_classes))
    elif isinstance(y_sampler, int):
        return torch.tensor([y_sampler] * batch_size, dtype=torch.long).to(device)
    else:
        return None

    # Repeat each chosen class 8 times, preserving the chosen order.
    repeated = [cls for cls in chosen for _ in range(8)]
    return torch.tensor(repeated, dtype=torch.long).to(device)
67
68
69
def sample_zy(z_prior, batch_size, z_dim, num_classes, truncation_factor, y_sampler, radius, device):
    """Jointly sample latent vectors and fake labels (plus perturbed latents).

    Labels are drawn first via ``sample_y`` (which may change the effective
    batch size, e.g. for "acending_all"); latents are then drawn from the
    requested prior. When ``radius`` is a positive float, a second latent
    batch ``zs_eps`` is produced by adding radius-scaled prior noise to ``zs``.

    Returns:
        (zs, fake_labels, zs_eps) where zs_eps is None unless ``radius`` is a
        positive float.
    """
    fake_labels = sample_y(y_sampler=y_sampler, batch_size=batch_size, num_classes=num_classes, device=device)
    # The label sampler decides the effective batch size.
    batch_size = fake_labels.shape[0]

    if z_prior == "gaussian":
        zs = sample_normal(batch_size=batch_size, z_dim=z_dim, truncation_factor=truncation_factor, device=device)
    elif z_prior == "uniform":
        zs = torch.FloatTensor(batch_size, z_dim).uniform_(-1.0, 1.0).to(device)
    else:
        raise NotImplementedError

    zs_eps = None
    if isinstance(radius, float) and radius > 0.0:
        # Perturbation noise matches the chosen prior (only "gaussian" or
        # "uniform" can reach this point).
        if z_prior == "gaussian":
            noise = sample_normal(batch_size, z_dim, -1.0, device)
        else:
            noise = torch.FloatTensor(batch_size, z_dim).uniform_(-1.0, 1.0).to(device)
        zs_eps = zs + radius * noise
    return zs, fake_labels, zs_eps
88
89
90
def generate_images(z_prior, truncation_factor, batch_size, z_dim, num_classes, y_sampler, radius, generator, discriminator,
                    is_train, LOSS, RUN, MODEL, device, is_stylegan, generator_mapping, generator_synthesis, style_mixing_p,
                    stylegan_update_emas, cal_trsp_cost):
    """Sample (z, y) pairs and decode them into fake images.

    Orchestrates the full sampling pipeline: draws latents/labels via
    ``sample_zy``, optionally concatenates InfoGAN codes onto ``zs``,
    optionally refines the latents with latent optimisation
    (``losses.latent_optimise``) and/or Langevin sampling, then decodes with
    either a StyleGAN-style mapping+synthesis pair or a plain conditional
    generator.

    Returns:
        (fake_images, fake_labels, fake_images_eps, trsp_cost, ws,
         info_discrete_c, info_conti_c) — ``fake_images_eps`` comes from the
        radius-perturbed latents (None when unused), ``trsp_cost`` is the
        latent-optimisation transport cost (None unless ``LOSS.apply_lo``),
        and ``ws`` are the StyleGAN W-space codes (None otherwise).
    """
    if is_train:
        # Truncation is an evaluation-time trick only.
        truncation_factor = -1.0
        lo_steps = LOSS.lo_steps4train
        apply_langevin = False  # NOTE(review): assigned but never read in this function.
    else:
        lo_steps = LOSS.lo_steps4eval
        if truncation_factor != -1:
            if is_stylegan:
                assert 0 <= truncation_factor <= 1, "Stylegan truncation_factor must lie btw 0(strong truncation) ~ 1(no truncation)"
            else:
                assert 0 <= truncation_factor, "truncation_factor must lie btw 0(strong truncation) ~ inf(no truncation)"

    # For StyleGAN, truncation is applied later in the mapping network
    # (as truncation_psi), so z itself is sampled untruncated here.
    zs, fake_labels, zs_eps = sample_zy(z_prior=z_prior,
                                        batch_size=batch_size,
                                        z_dim=z_dim,
                                        num_classes=num_classes,
                                        truncation_factor=-1 if is_stylegan else truncation_factor,
                                        y_sampler=y_sampler,
                                        radius=radius,
                                        device=device)
    # The label sampler may change the effective batch size (e.g. "acending_all").
    batch_size = fake_labels.shape[0]
    info_discrete_c, info_conti_c = None, None
    if MODEL.info_type in ["discrete", "both"]:
        # InfoGAN discrete codes, one-hot encoded and appended to z.
        info_discrete_c = torch.randint(MODEL.info_dim_discrete_c,(batch_size, MODEL.info_num_discrete_c), device=device)
        zs = torch.cat((zs, F.one_hot(info_discrete_c, MODEL.info_dim_discrete_c).view(batch_size, -1)), dim=1)
    if MODEL.info_type in ["continuous", "both"]:
        # InfoGAN continuous codes drawn uniformly from [-1, 1).
        info_conti_c = torch.rand(batch_size, MODEL.info_num_conti_c, device=device) * 2 - 1
        zs = torch.cat((zs, info_conti_c), dim=1)

    trsp_cost = None
    if LOSS.apply_lo:
        # Latent optimisation: refine z against the discriminator before decoding.
        zs, trsp_cost = losses.latent_optimise(zs=zs,
                                               fake_labels=fake_labels,
                                               generator=generator,
                                               discriminator=discriminator,
                                               batch_size=batch_size,
                                               lo_rate=LOSS.lo_rate,
                                               lo_steps=lo_steps,
                                               lo_alpha=LOSS.lo_alpha,
                                               lo_beta=LOSS.lo_beta,
                                               eval=not is_train,
                                               cal_trsp_cost=cal_trsp_cost,
                                               device=device)
    if not is_train and RUN.langevin_sampling:
        # Evaluation-only Langevin refinement of the latents.
        zs = langevin_sampling(zs=zs,
                               z_dim=z_dim,
                               fake_labels=fake_labels,
                               generator=generator,
                               discriminator=discriminator,
                               batch_size=batch_size,
                               langevin_rate=RUN.langevin_rate,
                               langevin_noise_std=RUN.langevin_noise_std,
                               langevin_decay=RUN.langevin_decay,
                               langevin_decay_steps=RUN.langevin_decay_steps,
                               langevin_steps=RUN.langevin_steps,
                               device=device)
    if is_stylegan:
        ws, fake_images = stylegan_generate_images(zs=zs,
                                                   fake_labels=fake_labels,
                                                   num_classes=num_classes,
                                                   style_mixing_p=style_mixing_p,
                                                   update_emas=stylegan_update_emas,
                                                   generator_mapping=generator_mapping,
                                                   generator_synthesis=generator_synthesis,
                                                   truncation_psi=truncation_factor,
                                                   truncation_cutoff=RUN.truncation_cutoff)
    else:
        fake_images = generator(zs, fake_labels, eval=not is_train)
        ws = None

    # Decode the perturbed latents with the SAME labels, so the pair
    # (fake_images, fake_images_eps) differs only by the latent perturbation.
    if zs_eps is not None:
        if is_stylegan:
            ws_eps, fake_images_eps = stylegan_generate_images(zs=zs_eps,
                                                               fake_labels=fake_labels,
                                                               num_classes=num_classes,
                                                               style_mixing_p=style_mixing_p,
                                                               update_emas=stylegan_update_emas,
                                                               generator_mapping=generator_mapping,
                                                               generator_synthesis=generator_synthesis,
                                                               truncation_psi=truncation_factor,
                                                               truncation_cutoff=RUN.truncation_cutoff)
        else:
            fake_images_eps = generator(zs_eps, fake_labels, eval=not is_train)
    else:
        fake_images_eps = None
    return fake_images, fake_labels, fake_images_eps, trsp_cost, ws, info_discrete_c, info_conti_c
179
180
def stylegan_generate_images(zs, fake_labels, num_classes, style_mixing_p, update_emas,
                             generator_mapping, generator_synthesis, truncation_psi, truncation_cutoff):
    """Map latents to W space and synthesize images with a StyleGAN generator.

    Args:
        zs: latent batch fed to the mapping network.
        fake_labels: integer class labels (one-hot encoded for conditioning).
        num_classes: number of classes for the one-hot encoding.
        style_mixing_p: probability of mixing a second latent's styles into
            the tail of the W codes (0 disables mixing).
        update_emas: forwarded to mapping/synthesis networks.
        generator_mapping / generator_synthesis: StyleGAN sub-networks.
        truncation_psi: -1 means "no truncation" (psi=1, no cutoff passed);
            otherwise forwarded together with ``truncation_cutoff``.

    Returns:
        (ws, fake_images) — W codes (after any style mixing) and images.
    """
    labels_onehot = F.one_hot(fake_labels, num_classes=num_classes)
    if truncation_psi == -1:
        # Sentinel: disable truncation entirely.
        ws = generator_mapping(zs, labels_onehot, truncation_psi=1, update_emas=update_emas)
    else:
        ws = generator_mapping(zs, labels_onehot, truncation_psi=truncation_psi,
                               truncation_cutoff=truncation_cutoff, update_emas=update_emas)
    if style_mixing_p > 0:
        # Draw a random mixing point, then keep it only with prob. style_mixing_p
        # (otherwise set it past the end so no layers are mixed).
        mix_point = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
        no_mix = torch.full_like(mix_point, ws.shape[1])
        mix_point = torch.where(torch.rand([], device=ws.device) < style_mixing_p, mix_point, no_mix)
        ws[:, mix_point:] = generator_mapping(torch.randn_like(zs), labels_onehot, update_emas=False)[:, mix_point:]
    fake_images = generator_synthesis(ws, update_emas=update_emas)
    return ws, fake_images
193
194
195
def langevin_sampling(zs, z_dim, fake_labels, generator, discriminator, batch_size, langevin_rate, langevin_noise_std,
                      langevin_decay, langevin_decay_steps, langevin_steps, device):
    """Refine latents with discriminator-guided Langevin dynamics.

    Each step descends the energy E(z) = -log p(z) - D(G(z, y)), where p is a
    standard normal prior, and adds Gaussian noise scaled by
    ``langevin_noise_std`` and sqrt(``langevin_rate``). Both the step size and
    the noise scaler optionally decay every ``langevin_decay_steps`` steps.

    Returns:
        The refined latent batch (a new tensor; the input is not mutated
        in place).
    """
    scaler = 1.0
    # Decay is active only when both knobs are positive.
    apply_decay = langevin_decay > 0 and langevin_decay_steps > 0
    mean = torch.zeros(z_dim, device=device)
    # NOTE(review): despite the "_std" names these are covariance matrices
    # (identity, and identity * langevin_noise_std) passed as covariance_matrix.
    prior_std = torch.eye(z_dim, device=device)
    lgv_std = prior_std * langevin_noise_std
    prior = MN.MultivariateNormal(loc=mean, covariance_matrix=prior_std)
    lgv_prior = MN.MultivariateNormal(loc=mean, covariance_matrix=lgv_std)
    for i in range(langevin_steps):
        # Re-wrap each step so gradients flow to the current latents only.
        zs = autograd.Variable(zs, requires_grad=True)
        fake_images = generator(zs, fake_labels, eval=True)
        fake_dict = discriminator(fake_images, fake_labels, eval=True)

        # Energy: negative prior log-density plus negative adversarial score.
        energy = -prior.log_prob(zs) - fake_dict["adv_output"]
        z_grads = losses.cal_deriv(inputs=zs, outputs=energy, device=device)

        # Langevin update: half-rate gradient step plus sqrt(rate)-scaled noise.
        zs = zs - 0.5 * langevin_rate * z_grads + (langevin_rate**0.5) * lgv_prior.sample([batch_size]) * scaler
        if apply_decay and (i + 1) % langevin_decay_steps == 0:
            langevin_rate *= langevin_decay
            scaler *= langevin_decay
    return zs
217
218
219
def sample_onehot(batch_size, num_classes, device="cuda"):
    """Draw ``batch_size`` uniform-random class indices in [0, num_classes).

    NOTE(review): despite the name, this returns integer class indices
    (int64, shape ``(batch_size,)``), not one-hot vectors.
    """
    return torch.randint(0, num_classes, (batch_size, ),
                         device=device, dtype=torch.int64, requires_grad=False)
226
227
228
def make_mask(labels, num_classes, mask_negatives, device):
    """Build a (num_classes, n_samples) class-membership mask.

    With ``mask_negatives`` True the mask is 0 everywhere except 1 at
    (c, j) where sample j has label c; with False the polarity is inverted
    (1 everywhere, 0 where the label matches).

    Returns:
        LongTensor of shape (num_classes, n_samples) on ``device``.
    """
    labels_np = labels.detach().cpu().numpy()
    n_samples = labels_np.shape[0]
    if mask_negatives:
        mask, fill = np.zeros([num_classes, n_samples]), 1.0
    else:
        mask, fill = np.ones([num_classes, n_samples]), 0.0

    # Mark each sample's own class row; labels outside [0, num_classes)
    # simply never match and leave their column untouched.
    for cls in range(num_classes):
        mask[cls, labels_np == cls] = fill
    return torch.tensor(mask).type(torch.long).to(device)
240
241
242
def make_target_cls_sampler(dataset, target_class):
    """Build a SubsetRandomSampler restricted to samples of one class.

    Args:
        dataset: a sized dataset exposing its labels either as
            ``dataset.data.targets`` or as ``dataset.labels``.
        target_class: the class whose sample indices are collected.

    Returns:
        (num_samples, sampler) — the number of matching samples and a
        ``torch.utils.data.sampler.SubsetRandomSampler`` over their indices.
    """
    try:
        targets = dataset.data.targets
    except AttributeError:
        # Narrowed from a bare ``except``: only a missing ``data``/``targets``
        # attribute should trigger the fallback; anything else (e.g.
        # KeyboardInterrupt) must propagate.
        targets = dataset.labels
    label_indices = [i for i in range(len(dataset)) if targets[i] == target_class]
    num_samples = len(label_indices)
    sampler = torch.utils.data.sampler.SubsetRandomSampler(label_indices)
    return num_samples, sampler
254
255