# Source: GitHub repository jantic/deoldify, blob/master/fastai/vision/gan.py
from ..torch_core import *
from ..layers import *
from ..callback import *
from ..basic_data import *
from ..basic_train import Learner, LearnerCallback
from .image import Image
from .data import ImageList

__all__ = ['basic_critic', 'basic_generator', 'GANModule', 'GANLoss', 'GANTrainer', 'FixedGANSwitcher', 'AdaptiveGANSwitcher',
           'GANLearner', 'NoisyItem', 'GANItemList', 'gan_critic', 'AdaptiveLoss', 'accuracy_thresh_expand',
           'GANDiscriminativeLR']

def AvgFlatten():
    "Takes the average of the input."
    return Lambda(lambda x: x.mean(0).view(1))

def basic_critic(in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, **conv_kwargs):
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    # No norm on the first layer (DCGAN-style critics skip norm on the input conv).
    layers = [conv_layer(n_channels, n_features, 4, 2, 1, leaky=0.2, norm_type=None, **conv_kwargs)]
    cur_size, cur_ftrs = in_size//2, n_features
    layers.append(nn.Sequential(*[conv_layer(cur_ftrs, cur_ftrs, 3, 1, leaky=0.2, **conv_kwargs) for _ in range(n_extra_layers)]))
    while cur_size > 4:
        layers.append(conv_layer(cur_ftrs, cur_ftrs*2, 4, 2, 1, leaky=0.2, **conv_kwargs))
        cur_ftrs *= 2 ; cur_size //= 2
    layers += [conv2d(cur_ftrs, 1, 4, padding=0), AvgFlatten()]
    return nn.Sequential(*layers)

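# Shape sketch for `basic_critic` (hypothetical 64x64 RGB input): the critic
# halves the spatial size until it reaches 4x4, then a padding-free 4x4 conv
# and `AvgFlatten` collapse the whole batch to a single score (WGAN-style):
#
#   critic = basic_critic(in_size=64, n_channels=3)
#   critic(torch.randn(16, 3, 64, 64)).shape  # -> torch.Size([1])
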
def basic_generator(in_size:int, n_channels:int, noise_sz:int=100, n_features:int=64, n_extra_layers=0, **conv_kwargs):
    "A basic generator from `noise_sz` to images `n_channels` x `in_size` x `in_size`."
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < in_size: cur_size *= 2; cur_ftrs *= 2
    layers = [conv_layer(noise_sz, cur_ftrs, 4, 1, transpose=True, **conv_kwargs)]
    cur_size = 4
    while cur_size < in_size // 2:
        layers.append(conv_layer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **conv_kwargs))
        cur_ftrs //= 2; cur_size *= 2
    layers += [conv_layer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **conv_kwargs) for _ in range(n_extra_layers)]
    layers += [conv2d_trans(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)

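# Mirror-image sketch for `basic_generator` (hypothetical sizes): the noise is
# fed as a `(bs, noise_sz, 1, 1)` tensor and upsampled by transposed convs
# until it reaches `in_size`, with a final `Tanh` squashing pixels to [-1, 1]:
#
#   gen = basic_generator(in_size=64, n_channels=3, noise_sz=100)
#   gen(torch.randn(16, 100, 1, 1)).shape  # -> torch.Size([16, 3, 64, 64])
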
class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self, generator:nn.Module=None, critic:nn.Module=None, gen_mode:bool=False):
        self.gen_mode = gen_mode
        if generator: self.generator,self.critic = generator,critic

    def forward(self, *args):
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode:bool=None):
        "Put the model in generator mode if `gen_mode`, in critic mode otherwise; toggle the mode if `gen_mode` is None."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

class GANLoss(GANModule):
    "Wrapper around `loss_funcC` (for the critic) and `loss_funcG` (for the generator)."
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule):
        super().__init__()
        self.loss_funcG,self.loss_funcC,self.gan_model = loss_funcG,loss_funcC,gan_model

    def generator(self, output, target):
        "Evaluate the `output` with the critic, then use `self.loss_funcG` to combine it with `target`."
        fake_pred = self.gan_model.critic(output)
        return self.loss_funcG(fake_pred, target, output)

    def critic(self, real_pred, input):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcC`."
        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
        fake_pred = self.gan_model.critic(fake)
        return self.loss_funcC(real_pred, fake_pred)

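# How `GANLoss` is driven, for orientation (hypothetical modules): it is itself
# a `GANModule` whose "generator" and "critic" are the two loss methods above,
# so switching modes also switches which loss `forward` computes:
#
#   gan  = GANModule(generator, critic)
#   loss = GANLoss(loss_funcG, loss_funcC, gan)
#   loss.switch(gen_mode=True)   # loss(output, target)   -> generator loss
#   loss.switch(gen_mode=False)  # loss(real_pred, input) -> critic loss
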
class GANTrainer(LearnerCallback):
    "Handles GAN Training."
    _order=-20
    def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False,
                 show_img:bool=True):
        super().__init__(learn)
        self.switch_eval,self.clip,self.beta,self.gen_first,self.show_img = switch_eval,clip,beta,gen_first,show_img
        self.generator,self.critic = self.model.generator,self.model.critic

    def _set_trainable(self):
        train_model = self.generator if self.gen_mode else self.critic
        loss_model = self.generator if not self.gen_mode else self.critic
        requires_grad(train_model, True)
        requires_grad(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def on_train_begin(self, **kwargs):
        "Create the optimizers for the generator and critic if necessary, initialize smootheners."
        if not getattr(self,'opt_gen',None):
            self.opt_gen = self.opt.new([nn.Sequential(*flatten_model(self.generator))])
        else: self.opt_gen.lr,self.opt_gen.wd = self.opt.lr,self.opt.wd
        if not getattr(self,'opt_critic',None):
            self.opt_critic = self.opt.new([nn.Sequential(*flatten_model(self.critic))])
        else: self.opt_critic.lr,self.opt_critic.wd = self.opt.lr,self.opt.wd
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.closses,self.glosses = [],[]
        self.smoothenerG,self.smoothenerC = SmoothenValue(self.beta),SmoothenValue(self.beta)
        self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        self.imgs,self.titles = [],[]

    def on_train_end(self, **kwargs):
        "Switch to generator mode for showing results."
        self.switch(gen_mode=True)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Clamp the weights with `self.clip` if it's not None, return the correct input."
        if self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        # In critic mode, input and target are swapped so the critic trains on
        # real images while the generator's input becomes the "target".
        return {'last_input':last_input,'last_target':last_target} if self.gen_mode else {'last_input':last_target,'last_target':last_input}

    def on_backward_begin(self, last_loss, last_output, **kwargs):
        "Record `last_loss` in the proper list."
        last_loss = last_loss.detach().cpu()
        if self.gen_mode:
            self.smoothenerG.add_value(last_loss)
            self.glosses.append(self.smoothenerG.smooth)
            self.last_gen = last_output.detach().cpu()
        else:
            self.smoothenerC.add_value(last_loss)
            self.closses.append(self.smoothenerC.smooth)

    def on_epoch_begin(self, epoch, **kwargs):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    def on_epoch_end(self, pbar, epoch, last_metrics, **kwargs):
        "Put the various losses in the recorder and show a sample image."
        if not hasattr(self, 'last_gen') or not self.show_img: return
        data = self.learn.data
        img = self.last_gen[0]
        norm = getattr(data,'norm',False)
        if norm and norm.keywords.get('do_y',False): img = data.denorm(img)
        img = data.train_ds.y.reconstruct(img)
        self.imgs.append(img)
        self.titles.append(f'Epoch {epoch}')
        pbar.show_imgs(self.imgs, self.titles)
        return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode:bool=None):
        "Switch the model to `gen_mode` if provided, otherwise toggle the current mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        # Swap the wrapped optimizer so generator and critic keep separate state.
        self.opt.opt = self.opt_gen.opt if self.gen_mode else self.opt_critic.opt
        self._set_trainable()
        self.model.switch(gen_mode)
        self.loss_func.switch(gen_mode)

class FixedGANSwitcher(LearnerCallback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    def __init__(self, learn:Learner, n_crit:Union[int,Callable]=1, n_gen:Union[int,Callable]=1):
        super().__init__(learn)
        self.n_crit,self.n_gen = n_crit,n_gen

    def on_train_begin(self, **kwargs):
        "Initialize the iteration counts."
        self.n_c,self.n_g = 0,0

    def on_batch_end(self, iteration, **kwargs):
        "Switch the model if necessary."
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0

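# Scheduling sketch: the defaults alternate one critic batch with one generator
# batch; `GANLearner` below defaults to a WGAN-style 5 critic steps per
# generator step. Either count may also be a callable receiving the other
# network's iteration count:
#
#   switcher = partial(FixedGANSwitcher, n_crit=5, n_gen=1)  # pass via `switcher=`
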
class AdaptiveGANSwitcher(LearnerCallback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`."
    def __init__(self, learn:Learner, gen_thresh:float=None, critic_thresh:float=None):
        super().__init__(learn)
        self.gen_thresh,self.critic_thresh = gen_thresh,critic_thresh

    def on_batch_end(self, last_loss, **kwargs):
        "Switch the model if necessary."
        if self.gan_trainer.gen_mode:
            if self.gen_thresh is None: self.gan_trainer.switch()
            elif last_loss < self.gen_thresh: self.gan_trainer.switch()
        else:
            if self.critic_thresh is None: self.gan_trainer.switch()
            elif last_loss < self.critic_thresh: self.gan_trainer.switch()

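# Usage sketch (hypothetical thresholds): keep training the generator until its
# loss falls below `gen_thresh`, and the critic until its loss falls below
# `critic_thresh`, switching sides each time a threshold is met:
#
#   switcher = partial(AdaptiveGANSwitcher, gen_thresh=2.0, critic_thresh=0.5)
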
def gan_loss_from_func(loss_gen, loss_crit, weights_gen:Tuple[float,float]=None):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    # Argument order matches the `loss_funcG(fake_pred, target, output)` call
    # made by `GANLoss.generator`.
    def _loss_G(fake_pred, target, output, weights_gen=weights_gen):
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C

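# Worked example (hypothetical loss functions): with `loss_gen=F.l1_loss` and
# `loss_crit=F.binary_cross_entropy_with_logits`, the generator loss becomes
#
#   weights_gen[0] * BCE(critic(fake), 1) + weights_gen[1] * L1(fake, target)
#
# so a weighting like `weights_gen=(1., 50.)` makes the pixel loss dominate and
# the adversarial term act as a regularizer.
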
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,
                 crit_loss_func:LossFunction, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True,
                 show_img:bool=True, clip:float=None, **learn_kwargs):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        switcher = ifnone(switcher, partial(FixedGANSwitcher, n_crit=5, n_gen=1))
        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)
        trainer = GANTrainer(self, clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        self.gan_trainer = trainer
        self.callbacks.append(trainer)

    @classmethod
    def from_learners(cls, learn_gen:Learner, learn_crit:Learner, switcher:Callback=None,
                      weights_gen:Tuple[float,float]=None, **learn_kwargs):
        "Create a GAN from `learn_gen` and `learn_crit`."
        losses = gan_loss_from_func(learn_gen.loss_func, learn_crit.loss_func, weights_gen=weights_gen)
        return cls(learn_gen.data, learn_gen.model, learn_crit.model, *losses, switcher=switcher, **learn_kwargs)

    @classmethod
    def wgan(cls, data:DataBunch, generator:nn.Module, critic:nn.Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs):
        "Create a WGAN from `data`, `generator` and `critic`."
        return cls(data, generator, critic, NoopLoss(), WassersteinLoss(), switcher=switcher, clip=clip, **learn_kwargs)

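# End-to-end WGAN sketch (hypothetical 64x64 data; `data` would come from a
# `GANItemList` pipeline like the one shown after `GANItemList` below):
#
#   generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
#   critic    = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
#   learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
#                           opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
#   learn.fit(2, 2e-4)
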
class NoisyItem(ItemBase):
    "A random `ItemBase` of size `noise_sz`."
    def __init__(self, noise_sz): self.obj,self.data = noise_sz,torch.randn(noise_sz, 1, 1)
    def __str__(self): return ''
    def apply_tfms(self, tfms, **kwargs): return self

class GANItemList(ImageList):
    "`ItemList` suitable for GANs."
    _label_cls = ImageList

    def __init__(self, items, noise_sz:int=100, **kwargs):
        super().__init__(items, **kwargs)
        self.noise_sz = noise_sz
        self.copy_new.append('noise_sz')

    def get(self, i): return NoisyItem(self.noise_sz)
    def reconstruct(self, t): return NoisyItem(t.size(0))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Shows `ys` (target images) on a figure of `figsize`."
        super().show_xys(ys, xs, imgsize=imgsize, figsize=figsize, **kwargs)

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Shows `zs` (generated images) on a figure of `figsize`."
        super().show_xys(zs, xs, imgsize=imgsize, figsize=figsize, **kwargs)

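# Typical pipeline sketch (hypothetical `path` and `tfms`): each input `x` is a
# `NoisyItem` and each label `y` is a real image, matching the input/target
# swap performed by `GANTrainer.on_batch_begin`:
#
#   data = (GANItemList.from_folder(path, noise_sz=100)
#           .split_none()
#           .label_from_func(noop)
#           .transform(tfms, size=64, tfm_y=True)
#           .databunch(bs=128))
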
_conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)

def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)

def gan_critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15):
    "Critic to train a `GAN`."
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        res_block(nf, dense=True, **_conv_args)]
    nf *= 2 # after dense block
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    layers += [
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten()]
    return nn.Sequential(*layers)

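# Unlike `basic_critic`, `gan_critic` uses spectral-normalized convs, dropout,
# a dense res block and self-attention, and returns one prediction per image
# (via `Flatten`) instead of one averaged score per batch, so it pairs with
# `AdaptiveLoss` below. Shape sketch (hypothetical 64x64 input):
#
#   critic = gan_critic()
#   critic(torch.randn(8, 3, 64, 64)).shape  # -> torch.Size([8, 1])
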
class GANDiscriminativeLR(LearnerCallback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    def __init__(self, learn:Learner, mult_lr:float = 5.):
        super().__init__(learn)
        self.mult_lr = mult_lr

    def on_batch_begin(self, train, **kwargs):
        "Multiply the current lr if necessary."
        if not self.learn.gan_trainer.gen_mode and train: self.learn.opt.lr *= self.mult_lr

    def on_step_end(self, **kwargs):
        "Put the LR back to its initial value if necessary."
        if not self.learn.gan_trainer.gen_mode: self.learn.opt.lr /= self.mult_lr

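# Usage sketch (hypothetical learner): give the critic a 5x higher learning
# rate than the generator by registering the callback on the learner:
#
#   learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
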
class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit):
        self.crit = crit

    def forward(self, output, target):
        return self.crit(output, target[:,None].expand_as(output).float())

def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    "Compute accuracy after expanding `y_true` to the size of `y_pred`."
    if sigmoid: y_pred = y_pred.sigmoid()
    return ((y_pred>thresh)==y_true[:,None].expand_as(y_pred).byte()).float().mean()

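# Critic pretraining sketch (hypothetical `data_crit` of real vs. generated
# images): `gan_critic` emits one logit per image, so the scalar class target
# has to be expanded to match, which `AdaptiveLoss` does for the loss and
# `accuracy_thresh_expand` for the metric:
#
#   learn_critic = Learner(data_crit, gan_critic(),
#                          loss_func=AdaptiveLoss(nn.BCEWithLogitsLoss()),
#                          metrics=accuracy_thresh_expand)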