GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/cyclegan.py
from ..torch_core import *
from ..layers import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback

__all__ = ['CycleGAN', 'CycleGanLoss', 'AdaptiveLoss', 'CycleGANTrainer']

def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):
    return [nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias),
            norm_layer(ch_out), nn.ReLU(True)]

def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,
                       pad:int=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:
    "Conv preceded by reflection/replication padding (or zero-padded conv), followed by `norm_layer` and optional ReLU."
    layers = []
    if pad_mode == 'reflection': layers.append(nn.ReflectionPad2d(pad))
    elif pad_mode == 'border':   layers.append(nn.ReplicationPad2d(pad))
    p = pad if pad_mode == 'zeros' else 0
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=p, stride=stride, bias=bias)
    if init:
        init(conv.weight)
        if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)
    layers += [conv, norm_layer(ch_out)]
    if activ: layers.append(nn.ReLU(inplace=True))
    return layers

class ResnetBlock(Module):
    "Residual block of two pad/conv/norm layers with a skip connection."
    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):
        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'
        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
        layers = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)
        if dropout != 0: layers.append(nn.Dropout(dropout))
        layers += pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x): return x + self.conv_block(x)

def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,
                     dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:
    "ResNet generator: 7x7 stem, two stride-2 downsampling convs, `n_blocks` residual blocks, two upsampling convs, 7x7 output conv with Tanh."
    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
    bias = (norm_layer == nn.InstanceNorm2d)
    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)
    for i in range(2):
        layers += pad_conv_norm_relu(n_ftrs, n_ftrs*2, 'zeros', norm_layer, stride=2, bias=bias)
        n_ftrs *= 2
    layers += [ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks)]
    for i in range(2):
        layers += convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias)
        n_ftrs //= 2
    layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]
    return nn.Sequential(*layers)

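# Illustrative note (not in the original source): for even-sized inputs the two stride-2 downsamples
# are exactly undone by the two transposed convs, so with the defaults `resnet_generator(3, 3)` maps a
# (bs, 3, H, W) batch to a (bs, 3, H, W) output squashed into [-1, 1] by the final Tanh.
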
def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,
                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:
    "Conv layer followed by an optional `norm_layer` and an optional LeakyReLU with the given `slope`."
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)
    if init:
        init(conv.weight)
        if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)
    layers = [conv]
    if norm_layer is not None: layers.append(norm_layer(ch_out))
    if activ: layers.append(nn.LeakyReLU(slope, inplace=True))
    return layers

def critic(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:
    "PatchGAN critic: a stack of stride-2 conv/norm/LeakyReLU blocks ending in a 1-channel conv."
    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
    bias = (norm_layer == nn.InstanceNorm2d)
    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)
    for i in range(n_layers-1):
        new_ftrs = 2*n_ftrs if i <= 3 else n_ftrs
        layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias)
        n_ftrs = new_ftrs
    new_ftrs = 2*n_ftrs if n_layers <= 3 else n_ftrs
    layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias)
    layers.append(nn.Conv2d(new_ftrs, 1, kernel_size=4, stride=1, padding=1))
    if sigmoid: layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)

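# Illustrative note (not in the original source): this is a PatchGAN-style critic, so it scores
# overlapping patches rather than whole images; e.g. `critic(3)(torch.randn(1, 3, 256, 256))`
# would return a (1, 1, 30, 30) map of per-patch real/fake scores.
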
class CycleGAN(Module):
    "CycleGAN model: two generators (one per direction) and two critics."

    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,
                 drop:float=0., norm_layer:nn.Module=None):
        self.D_A = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)
        self.D_B = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)
        self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        #G_A: takes real input B and generates fake input A
        #G_B: takes real input A and generates fake input B
        #D_A: trained to tell the difference between real input A and fake input A
        #D_B: trained to tell the difference between real input B and fake input B

    def forward(self, real_A, real_B):
        fake_A, fake_B = self.G_A(real_B), self.G_B(real_A)
        if not self.training: return torch.cat([fake_A[:,None],fake_B[:,None]], 1)
        idt_A, idt_B = self.G_A(real_A), self.G_B(real_B)
        return [fake_A, fake_B, idt_A, idt_B]

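# Illustrative sketch (not part of the original module) of how the model above is called; the
# `torch.randn` tensors stand in for real image batches:
#
#   cgan = CycleGAN(3, 3)
#   real_A, real_B = torch.randn(4, 3, 128, 128), torch.randn(4, 3, 128, 128)
#   fake_A, fake_B, idt_A, idt_B = cgan(real_A, real_B)   # training mode
#   cgan.eval()
#   fakes = cgan(real_A, real_B)                          # (4, 2, 3, 128, 128): fake_A and fake_B stacked
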
class AdaptiveLoss(Module):
    "Wrap `crit` so it can take a boolean `target`, expanded to a tensor of ones or zeros shaped like `output`."
    def __init__(self, crit): self.crit = crit

    def forward(self, output, target:bool):
        targ = output.new_ones(*output.size()) if target else output.new_zeros(*output.size())
        return self.crit(output, targ)

class CycleGanLoss(Module):
    "CycleGAN loss: identity, adversarial (generator) and cycle-consistency terms."
    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10., lambda_idt:float=0.5, lsgan:bool=True):
        self.cgan,self.l_A,self.l_B,self.l_idt = cgan,lambda_A,lambda_B,lambda_idt
        #self.crit = F.mse_loss if lsgan else F.binary_cross_entropy
        self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)

    def set_input(self, input):
        self.real_A,self.real_B = input

    def forward(self, output, target):
        fake_A, fake_B, idt_A, idt_B = output
        #Generators should return identity on the datasets they try to convert to
        idt_loss = self.l_idt * (self.l_A * F.l1_loss(idt_A, self.real_A) + self.l_B * F.l1_loss(idt_B, self.real_B))
        #Generators are trained to trick the critics, so the targets for the critic outputs are ones
        gen_loss = self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)
        #Cycle loss
        cycle_loss = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)
        cycle_loss += self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)
        self.metrics = [idt_loss, gen_loss, cycle_loss]
        return idt_loss + gen_loss + cycle_loss

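# A sketch (not in the original source) of the full generator objective computed above, writing
# a = real_A, b = real_B; the critic targets are all ones because this loss only updates the generators:
#   L = lambda_idt * (lambda_A*||G_A(a) - a||_1 + lambda_B*||G_B(b) - b||_1)    (identity)
#     + crit(D_A(G_A(b)), 1) + crit(D_B(G_B(a)), 1)                             (adversarial)
#     + lambda_A*||G_A(G_B(a)) - a||_1 + lambda_B*||G_B(G_A(b)) - b||_1         (cycle)
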
class CycleGANTrainer(LearnerCallback):
    "`LearnerCallback` that handles CycleGAN training."
    _order=-20
    def _set_trainable(self, D_A=False, D_B=False):
        "Enable gradients for the generators or for the requested critic(s), and sync the critic optimizers' hyper-parameters."
        gen = (not D_A) and (not D_B)
        requires_grad(self.learn.model.G_A, gen)
        requires_grad(self.learn.model.G_B, gen)
        requires_grad(self.learn.model.D_A, D_A)
        requires_grad(self.learn.model.D_B, D_B)
        if not gen:
            self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom
            self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta
            self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom
            self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta

    def on_train_begin(self, **kwargs):
        "Create the various optimizers."
        self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B
        self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B
        self.crit = self.learn.loss_func.crit
        self.opt_G = self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])
        self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])
        self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])
        self.learn.opt.opt = self.opt_G.opt
        self._set_trainable()
        self.names = ['idt_loss', 'gen_loss', 'cyc_loss', 'da_loss', 'db_loss']
        self.learn.recorder.no_val=True
        self.learn.recorder.add_metric_names(self.names)
        self.smootheners = {n:SmoothenValue(0.98) for n in self.names}

    def on_batch_begin(self, last_input, **kwargs):
        "Register the `last_input` in the loss function."
        self.learn.loss_func.set_input(last_input)

    def on_batch_end(self, last_input, last_output, **kwargs):
        "Steps through the generators then each of the critics."
        self.G_A.zero_grad(); self.G_B.zero_grad()
        fake_A, fake_B = last_output[0].detach(), last_output[1].detach()
        real_A, real_B = last_input
        self._set_trainable(D_A=True)
        self.D_A.zero_grad()
        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))
        loss_D_A.backward()
        self.opt_D_A.step()
        self._set_trainable(D_B=True)
        self.D_B.zero_grad()
        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))
        loss_D_B.backward()
        self.opt_D_B.step()
        self._set_trainable()
        metrics = self.learn.loss_func.metrics + [loss_D_A, loss_D_B]
        for n,m in zip(self.names,metrics): self.smootheners[n].add_value(m)

    def on_epoch_end(self, last_metrics, **kwargs):
        "Put the various losses in the recorder."
        return add_metrics(last_metrics, [s.smooth for k,s in self.smootheners.items()])

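# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal way to wire the pieces above into a fastai v1 `Learner`. The DataBunch `data`
# (assumed to yield (real_A, real_B) batches) and the optimizer/hyper-parameter choices below
# are assumptions for illustration only:
#
#   cycle_gan = CycleGAN(3, 3, gen_blocks=9)
#   learn = Learner(data, cycle_gan, loss_func=CycleGanLoss(cycle_gan),
#                   opt_func=partial(optim.Adam, betas=(0.5, 0.99)),
#                   callback_fns=[CycleGANTrainer])
#   learn.fit(100, 1e-4)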