Path: blob/master/Face-Recognition-with-ArcFace/align/get_nets.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, c, h, w].
        Returns:
            a float tensor with shape [batch_size, c*h*w].
        """
        # without this transpose the pretrained model doesn't work
        x = x.transpose(3, 2).contiguous()
        return x.view(x.size(0), -1)


class PNet(nn.Module):

    def __init__(self):
        super(PNet, self).__init__()

        # suppose we have an input of size HxW, then:
        # after the first conv: H - 2,
        # after the pool: ceil((H - 2)/2),
        # after the second conv: ceil((H - 2)/2) - 2,
        # after the last conv: ceil((H - 2)/2) - 4,
        # and the same for W.
        # For example, a 12x12 input yields a 1x1 output map.
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 10, 3, 1)),
            ('prelu1', nn.PReLU(10)),
            ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),

            ('conv2', nn.Conv2d(10, 16, 3, 1)),
            ('prelu2', nn.PReLU(16)),

            ('conv3', nn.Conv2d(16, 32, 3, 1)),
            ('prelu3', nn.PReLU(32))
        ]))

        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)

        # load pretrained weights (a dict of numpy arrays keyed by parameter name)
        weights = np.load("align/pnet.npy", allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4, h', w'].
            a: a float tensor with shape [batch_size, 2, h', w'].
        """
        x = self.features(x)
        a = self.conv4_1(x)
        b = self.conv4_2(x)
        a = F.softmax(a, dim=1)
        return b, a


class RNet(nn.Module):

    def __init__(self):
        super(RNet, self).__init__()

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 28, 3, 1)),
            ('prelu1', nn.PReLU(28)),
            ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv2', nn.Conv2d(28, 48, 3, 1)),
            ('prelu2', nn.PReLU(48)),
            ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv3', nn.Conv2d(48, 64, 2, 1)),
            ('prelu3', nn.PReLU(64)),

            ('flatten', Flatten()),
            ('conv4', nn.Linear(576, 128)),
            ('prelu4', nn.PReLU(128))
        ]))

        self.conv5_1 = nn.Linear(128, 2)
        self.conv5_2 = nn.Linear(128, 4)

        # load pretrained weights (a dict of numpy arrays keyed by parameter name)
        weights = np.load("align/rnet.npy", allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv5_1(x)
        b = self.conv5_2(x)
        a = F.softmax(a, dim=1)
        return b, a


class ONet(nn.Module):

    def __init__(self):
        super(ONet, self).__init__()

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, 3, 1)),
            ('prelu1', nn.PReLU(32)),
            ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv2', nn.Conv2d(32, 64, 3, 1)),
            ('prelu2', nn.PReLU(64)),
            ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv3', nn.Conv2d(64, 64, 3, 1)),
            ('prelu3', nn.PReLU(64)),
            ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),

            ('conv4', nn.Conv2d(64, 128, 2, 1)),
            ('prelu4', nn.PReLU(128)),

            ('flatten', Flatten()),
            ('conv5', nn.Linear(1152, 256)),
            ('drop5', nn.Dropout(0.25)),
            ('prelu5', nn.PReLU(256)),
        ]))

        self.conv6_1 = nn.Linear(256, 2)
        self.conv6_2 = nn.Linear(256, 4)
        self.conv6_3 = nn.Linear(256, 10)

        # load pretrained weights (a dict of numpy arrays keyed by parameter name)
        weights = np.load("align/onet.npy", allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            c: a float tensor with shape [batch_size, 10].
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv6_1(x)
        b = self.conv6_2(x)
        c = self.conv6_3(x)
        a = F.softmax(a, dim=1)
        return c, b, a
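# --- Usage sketch (not part of the original file): a minimal smoke test of the
# three networks, assuming the converted weight files align/pnet.npy,
# align/rnet.npy and align/onet.npy are available relative to the working
# directory (they are loaded in each constructor above). PNet is fully
# convolutional and accepts any input of at least 12x12, while RNet and ONet
# expect fixed 24x24 and 48x48 crops respectively.
if __name__ == "__main__":
    pnet, rnet, onet = PNet(), RNet(), ONet()
    for net in (pnet, rnet, onet):
        net.eval()

    with torch.no_grad():
        # PNet: box offsets b are [1, 4, h', w'], face probabilities a are [1, 2, h', w']
        b, a = pnet(torch.randn(1, 3, 100, 120))
        print(b.shape, a.shape)

        # RNet on 24x24 crops: b is [n, 4], a is [n, 2]
        b, a = rnet(torch.randn(5, 3, 24, 24))
        print(b.shape, a.shape)

        # ONet on 48x48 crops: c holds 10 landmark coordinates per face
        c, b, a = onet(torch.randn(5, 3, 48, 48))
        print(c.shape, b.shape, a.shape)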