Path: blob/master/Face-Recognition-with-ArcFace/backbone.py
3118 views
# Original code
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py
"""Residual ("IR") face backbone that produces 512-dimensional embeddings.

The body stacks four stages of 3, 4, 14 and 3 ``bottleneck_IR`` units. The
first unit of every stage uses stride 2, so the final feature map is 1/16 of
the input side: 7x7 for 112-pixel inputs, 14x14 for 224-pixel inputs.
"""

from collections import namedtuple

import torch
import torch.nn as nn

# Input sides the embedding head is sized for.
_SUPPORTED_INPUT_SIDES = (112, 224)
# Four stages, each halving the resolution once in its first unit: 2 ** 4.
_TOTAL_STRIDE = 16


class bottleneck_IR(nn.Module):
    """IR residual unit: ``BN-Conv3x3-PReLU-Conv3x3-BN`` plus a shortcut.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride of the second 3x3 conv and of the shortcut; the unit
            downsamples spatially when ``stride > 1``.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        if in_channel == depth:
            # A 1x1 max-pool passes values through unchanged but can still
            # subsample by ``stride``, so no learned projection is needed.
            self.shortcut_layer = nn.MaxPool2d(1, stride)
        else:
            # Channel counts differ: project with a strided 1x1 conv + BN.
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                nn.BatchNorm2d(depth),
            )
        self.res_layer = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            nn.PReLU(depth),
            nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            nn.BatchNorm2d(depth),
        )

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
    """A named tuple describing a ResNet block."""

    # Without an empty ``__slots__`` this subclass would give every instance
    # a ``__dict__``, defeating namedtuple's lightweight layout.
    __slots__ = ()


def get_block(in_channel, depth, num_units, stride=2):
    """Describe one stage as a list of ``Bottleneck`` specs.

    The first unit maps ``in_channel -> depth`` and uses ``stride``; each of
    the remaining ``num_units - 1`` units maps ``depth -> depth`` with
    stride 1. (For ``num_units <= 1`` only the first unit is returned.)

    Returns:
        list[Bottleneck]: the stage's unit specs, in order.
    """
    return [Bottleneck(in_channel, depth, stride)] + [
        Bottleneck(depth, depth, 1) for _ in range(num_units - 1)
    ]


class Backbone(nn.Module):
    """Stem conv -> 24 ``bottleneck_IR`` units -> 512-d embedding head.

    Args:
        input_size: sequence whose first element is the input image side
            (112 or 224). Only ``input_size[0]`` is inspected; the head
            assumes square inputs of that side.

    Raises:
        ValueError: if ``input_size[0]`` is not 112 or 224.
    """

    def __init__(self, input_size):
        super().__init__()
        side = input_size[0]
        # Explicit check rather than ``assert`` so it is not stripped by -O.
        if side not in _SUPPORTED_INPUT_SIDES:
            raise ValueError("input_size should be [112, 112] or [224, 224]")

        stages = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3),
        ]

        # Construction order (input -> output -> body) is kept from the
        # original so parameter registration and RNG-driven init match.
        self.input_layer = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.PReLU(64),
        )

        # Side of the final 512-channel feature map: 7 for 112, 14 for 224.
        feat_side = int(side) // _TOTAL_STRIDE
        self.output_layer = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Dropout(),  # PyTorch default p=0.5
            nn.Flatten(),
            nn.Linear(512 * feat_side * feat_side, 512),
            nn.BatchNorm1d(512),
        )

        self.body = nn.Sequential(*[
            bottleneck_IR(unit.in_channel, unit.depth, unit.stride)
            for stage in stages
            for unit in stage
        ])

    def forward(self, x):
        """Map an image batch to embeddings.

        Args:
            x: image batch; presumably ``(N, 3, S, S)`` with ``S`` equal to
                ``input_size[0]`` (the stem takes 3 channels and the Linear
                head is sized for that side). In training mode the final
                ``BatchNorm1d`` needs ``N > 1``.

        Returns:
            ``(N, 512)`` embedding tensor.
        """
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return x