# Path: blob/master/FBAMatting/networks/resnet_GN_WS.py
# (source page metadata: 3119 views)
import networks.layers_WS as L
import torch.nn as nn

__all__ = ["ResNet", "l_resnet50"]

# NOTE(review): `L` is `networks.layers_WS`. Judging by that module's name and
# this file's name (resnet_GN_WS), `L.Conv2d` presumably adds weight
# standardisation and `L.BatchNorm2d` may be a GroupNorm stand-in -- confirm
# in layers_WS.py; nothing in this file establishes it.


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``L.Conv2d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
    """
    return L.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``L.Conv2d`` (no padding).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
    """
    return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions plus a shortcut connection.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions (``expansion == 1``).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to build the
            shortcut; ``ResNet._make_layer`` supplies one whenever the stride
            or channel count changes so the residual addition lines up.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = L.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = L.BatchNorm2d(planes)
        self.downsample = downsample
        # Not read by forward(); kept so existing attribute access still works.
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block 1x1 -> 3x3 -> 1x1 whose output has ``4 * planes`` channels.

    The stride is applied in the middle 3x3 convolution.

    Args:
        inplanes: channels of the block input.
        planes: bottleneck width; the block outputs ``planes * expansion``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the
            shortcut; ``ResNet._make_layer`` supplies one whenever the stride
            or channel count changes so the residual addition lines up.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = L.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = L.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = L.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Not read by forward(); kept so existing attribute access still works.
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/norm stages, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet built from ``L`` convolutions and normalisation layers.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute sets the per-stage output-channel
            multiplier.
        layers: number of blocks in each of the four stages (indexed 0..3).
        num_classes: output size of the final ``nn.Linear`` classifier.

    The stem convolution expects 3 input channels.

    ``self.maxpool`` is built with ``return_indices=True`` and therefore
    returns an ``(output, indices)`` tuple. It is left that way so that code
    calling it directly (presumably to reuse the indices, e.g. for
    unpooling -- not visible in this file) keeps working. ``forward`` below
    discards the indices.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.inplanes = 64
        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = L.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True,
        )
        # Stages must be built in this order: _make_layer mutates self.inplanes.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block gets ``stride`` and, when the stride or channel
        count changes, a 1x1 conv + norm projection shortcut. Side effect:
        ``self.inplanes`` is set to ``planes * block.expansion`` for the
        next stage.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                L.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the classifier; returns a ``(batch, num_classes)`` tensor.

        Assumes ``x`` is an ``(N, 3, H, W)`` image batch -- the stem conv is
        declared with 3 input channels.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # maxpool was built with return_indices=True and yields
        # (output, indices). Passing the tuple on would crash layer1, so the
        # indices are dropped here.
        x, _ = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse (N, C, 1, 1) -> (N, C); unlike view(), also safe for
        # non-contiguous tensors.
        x = x.flatten(1)
        x = self.fc(x)

        return x


def l_resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50: ``Bottleneck`` blocks, stage depths [3, 4, 6, 3].

    Args:
        pretrained: accepted for torchvision-style API compatibility, but this
            function loads no weights. The model always has its modules'
            default initialisation. A truthy value triggers a ``UserWarning``
            rather than silently pretending to load ImageNet weights.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        A ``ResNet`` instance.
    """
    if pretrained:
        import warnings

        warnings.warn(
            "l_resnet50(pretrained=True): no pretrained weights are loaded by "
            "this function; the returned model is not ImageNet-pretrained.",
            stacklevel=2,
        )
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)