# Path: blob/master/FBAMatting/networks/resnet_bn.py
# 3119 views
import math

import torch.nn as nn
from torch.nn import BatchNorm2d

__all__ = ["ResNet", "BasicBlock", "Bottleneck", "l_resnet50"]


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With padding 1 the spatial size is preserved for ``stride=1`` and becomes
    ``ceil(H / 2)`` for ``stride=2``.
    """
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False,
    )


class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)`` where the
    shortcut is ``downsample(x)`` when ``downsample`` is given, else ``x``.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``expansion`` is 1).
        stride: stride of the first 3x3 convolution.
        downsample: optional module mapping ``x`` to the residual branch; it
            must be supplied whenever ``stride != 1`` or ``inplanes != planes``
            so that both branches have the same shape for the addition.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels.

    Args:
        inplanes: number of input channels.
        planes: width of the inner 3x3 convolution; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the middle 3x3 convolution.
        downsample: optional module mapping ``x`` to the residual branch; it
            must be supplied whenever the shape changes
            (``stride != 1`` or ``inplanes != planes * 4``).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False,
        )
        # NOTE(review): only this BatchNorm uses momentum=0.01 (all others use
        # PyTorch's default); it changes running-stat updates during training,
        # so it is kept as-is -- confirm it is intentional.
        self.bn2 = BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet with a three-convolution stem and a classification head.

    The stem is conv3x3(3->64, stride 2) -> conv3x3(64->64) -> conv3x3(64->128),
    each followed by BatchNorm + ReLU, then a 3x3/stride-2 max pool created with
    ``return_indices=True`` (so callers reusing ``self.maxpool`` get the
    indices; ``forward`` discards them).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute determines each stage's output width.
        layers: four ints, the number of blocks in ``layer1``..``layer4``.
        num_classes: output size of the final ``fc`` layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # The stem ends with 128 channels, so layer1 starts from 128 planes.
        self.inplanes = 128
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True,
        )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pool to 1x1. For a 7x7 layer4 map (224x224 input) this
        # equals AvgPool2d(7, stride=1); unlike a fixed 7x7 kernel it also
        # works for any other input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / fan_out), fan_out = kh*kw*out_ch.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1 conv + BatchNorm ``downsample`` shortcut. Updates
        ``self.inplanes`` to the stage's output width for the next stage.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x, _indices = self.maxpool(x)  # pooling indices are unused here

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def l_resnet50():
    """Construct a ResNet-50 (Bottleneck blocks, ``[3, 4, 6, 3]``).

    The model uses the three-convolution stem and is randomly initialised;
    no pretrained weights are loaded here.

    Returns:
        ResNet: the model with the default 1000-way ``fc`` head.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    return model