GitHub Repository: jantic/deoldify
Path: blob/master/fid/inception.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=[DEFAULT_BLOCK_INDEX],
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, 'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.Tensor, corresponding to the selected output
        blocks, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


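# Example usage (a minimal sketch): `images` is assumed to be a float tensor
# of shape (N, 3, H, W) with values in the range (0, 1), as `forward` expects.
#
#     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
#     model = InceptionV3(output_blocks=[block_idx]).eval()
#     with torch.no_grad():
#         features = model(images)[0]            # (N, 2048, 1, 1)
#     features = features.squeeze(3).squeeze(2)  # (N, 2048)

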
def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This function first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = models.inception_v3(
        num_classes=1008, aux_logits=False, pretrained=False
    )
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
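
# FID compares the Gaussian statistics (mean and covariance) of the 2048-d
# activations between two image sets. A minimal sketch of those statistics,
# assuming `features` is an (N, 2048) tensor obtained as in the example after
# the InceptionV3 class:
#
#     import numpy as np
#     acts = features.cpu().numpy()
#     mu = acts.mean(axis=0)
#     sigma = np.cov(acts, rowvar=False)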