Path: blob/master/src/metrics/inception_net.py
809 views
from torchvision import models
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # torchvision < 0.12 location of the download helper
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning pooled feature maps and logits.

    The network is split into four sequential blocks; the final block ends in
    a global average pool producing the 2048-d features conventionally used
    for FID, followed by the ported 1008-way TF classifier head.
    """

    def __init__(self, resize_input=True, normalize_input=False, requires_grad=False):
        """Build pretrained InceptionV3.

        Parameters
        ----------
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed.
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1).
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly
            useful for finetuning the network.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.blocks = nn.ModuleList()

        state_dict, inception = fid_inception_v3()

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        block1 = [
            inception.Conv2d_3b_1x1,
            inception.Conv2d_4a_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        block2 = [
            inception.Mixed_5b,
            inception.Mixed_5c,
            inception.Mixed_5d,
            inception.Mixed_6a,
            inception.Mixed_6b,
            inception.Mixed_6c,
            inception.Mixed_6d,
            inception.Mixed_6e,
        ]
        self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        block3 = [
            inception.Mixed_7a,
            inception.Mixed_7b,
            inception.Mixed_7c,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        ]
        self.blocks.append(nn.Sequential(*block3))

        # Classifier head taken straight from the ported TF weights; copying
        # inside no_grad keeps the weight transfer out of the autograd graph.
        with torch.no_grad():
            self.fc = nn.Linear(2048, 1008, bias=True)
            self.fc.weight.copy_(state_dict['fc.weight'])
            self.fc.bias.copy_(state_dict['fc.bias'])

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception pooled features and classifier logits.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1).

        Returns
        -------
        tuple of (torch.Tensor, torch.Tensor)
            ``(features, logits)`` where ``features`` are the flattened
            Bx2048 global-average-pooled activations and ``logits`` are the
            Bx1008 outputs of the ported classifier head.
        """
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for block in self.blocks:
            x = block(x)

        # With training=False this dropout is an identity; kept for parity
        # with the reference FID implementation.
        x = F.dropout(x, training=False)
        x = torch.flatten(x, 1)
        logit = self.fc(x)
        return x, logit


def fid_inception_v3():
    """Build pretrained Inception model for FID computation.

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.
    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.

    Returns
    -------
    tuple of (dict, torch.nn.Module)
        The raw downloaded state dict (the caller also reads the ``fc.*``
        entries from it) and the patched, weight-loaded Inception network.
    """
    inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # Swap in the FID-specific block variants before loading the ported
    # weights so parameter names/shapes line up with the checkpoint.
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)
    # inception.fc = nn.Linear(2048, 1008, bias=False)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return state_dict, inception


class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation."""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation."""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation."""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation."""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)