Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/inception_net.py
809 views
1
from torchvision import models
2
import torch
3
import torch.nn as nn
4
import torch.nn.functional as F
5
6
try:
7
from torchvision.models.utils import load_state_dict_from_url
8
except ImportError:
9
from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
11
# Inception weights ported to Pytorch from
12
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
15
16
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning pooled feature maps and logits.

    Wraps the FID-specific Inception model (see ``fid_inception_v3``) and
    exposes the 2048-d pre-logit features together with the 1008-way logits,
    as needed for FID / Inception Score computation.
    """

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    def __init__(self, resize_input=True, normalize_input=False, requires_grad=False):
        """Build pretrained InceptionV3

        Parameters
        ----------
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.blocks = nn.ModuleList()

        # fid_inception_v3 returns both the raw checkpoint (needed below to
        # rebuild the fc head) and the patched torchvision Inception model.
        state_dict, inception = fid_inception_v3()

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
        self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        block2 = [
            inception.Mixed_5b,
            inception.Mixed_5c,
            inception.Mixed_5d,
            inception.Mixed_6a,
            inception.Mixed_6b,
            inception.Mixed_6c,
            inception.Mixed_6d,
            inception.Mixed_6e,
        ]
        self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        block3 = [inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, nn.AdaptiveAvgPool2d(output_size=(1, 1))]
        self.blocks.append(nn.Sequential(*block3))

        # Rebuild the 1008-way classifier head directly from the checkpoint so
        # the logits match the original TensorFlow Inception model.
        with torch.no_grad():
            self.fc = nn.Linear(2048, 1008, bias=True)
            self.fc.weight.copy_(state_dict['fc.weight'])
            self.fc.bias.copy_(state_dict['fc.bias'])

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        (features, logits) : tuple of torch.Tensor
            features is the Bx2048 pre-logit pooled feature map,
            logits is the Bx1008 classifier output
        """
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for block in self.blocks:
            x = block(x)

        # NOTE: the reference implementation applied
        # F.dropout(x, training=False) here, which is an identity op; it has
        # been removed without changing behavior.
        x = torch.flatten(x, 1)
        logit = self.fc(x)
        return x, logit
108
109
110
def fid_inception_v3():
    """Construct the pretrained Inception model used for FID computation.

    The FID Inception model uses a different set of weights and a slightly
    different structure than torchvision's Inception. This builds
    torchvision's Inception first, then swaps in the FID-specific variants
    of the mixed blocks before loading the ported TensorFlow weights.

    Returns
    -------
    (state_dict, inception) : tuple
        The downloaded checkpoint and the patched, weight-loaded model.
    """
    inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # FID-specific replacements for the blocks that differ from torchvision.
    replacements = {
        'Mixed_5b': FIDInceptionA(192, pool_features=32),
        'Mixed_5c': FIDInceptionA(256, pool_features=64),
        'Mixed_5d': FIDInceptionA(288, pool_features=64),
        'Mixed_6b': FIDInceptionC(768, channels_7x7=128),
        'Mixed_6c': FIDInceptionC(768, channels_7x7=160),
        'Mixed_6d': FIDInceptionC(768, channels_7x7=160),
        'Mixed_6e': FIDInceptionC(768, channels_7x7=192),
        'Mixed_7b': FIDInceptionE_1(1280),
        'Mixed_7c': FIDInceptionE_2(2048),
    }
    for attr, module in replacements.items():
        setattr(inception, attr, module)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return state_dict, inception
133
134
135
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation."""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x)))

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        out_pool = self.branch_pool(
            F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                         count_include_pad=False))

        return torch.cat([out_1x1, out_5x5, out_3x3dbl, out_pool], 1)
157
158
159
class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation."""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_3(self.branch7x7_2(self.branch7x7_1(x)))

        out_7x7dbl = self.branch7x7dbl_1(x)
        for layer in (self.branch7x7dbl_2, self.branch7x7dbl_3,
                      self.branch7x7dbl_4, self.branch7x7dbl_5):
            out_7x7dbl = layer(out_7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        out_pool = self.branch_pool(
            F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                         count_include_pad=False))

        return torch.cat([out_1x1, out_7x7, out_7x7dbl, out_pool], 1)
184
185
186
class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation."""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        # 3x3 branch splits into two parallel convolutions whose outputs are
        # concatenated along the channel dimension.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            [self.branch3x3dbl_3a(stem_3x3dbl),
             self.branch3x3dbl_3b(stem_3x3dbl)], 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        out_pool = self.branch_pool(
            F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                         count_include_pad=False))

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)
216
217
218
class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation."""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            [self.branch3x3dbl_3a(stem_3x3dbl),
             self.branch3x3dbl_3b(stem_3x3dbl)], 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        out_pool = self.branch_pool(
            F.max_pool2d(x, kernel_size=3, stride=1, padding=1))

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)
250
251