Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FBAMatting/networks/resnet_GN_WS.py
3119 views
1
import networks.layers_WS as L
2
import torch.nn as nn
3
4
# Names exported when this module is imported with a wildcard (`from ... import *`).
__all__ = ["ResNet", "l_resnet50"]
5
6
7
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``L.Conv2d`` with padding 1.

    With ``stride=1`` the padding keeps the spatial size unchanged.
    ``L`` is ``networks.layers_WS``; presumably its ``Conv2d`` is a
    weight-standardised convolution given the module name — confirm there.
    """
    layer_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return L.Conv2d(in_planes, out_planes, **layer_options)
12
13
14
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``L.Conv2d`` (pointwise channel projection)."""
    layer_options = dict(kernel_size=1, stride=stride, bias=False)
    return L.Conv2d(in_planes, out_planes, **layer_options)
17
18
19
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and a shortcut connection.

    Args:
        inplanes: number of input channels.
        planes: output channels of both 3x3 convolutions.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the block input to build the
            shortcut; when ``None`` the input itself is the shortcut.

    Normalisation layers come from ``L.BatchNorm2d`` (``networks.layers_WS``);
    given the file name this is presumably a GroupNorm variant — confirm.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = L.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = L.BatchNorm2d(planes)
        self.downsample = downsample
        # Stored for introspection; not read inside this block.
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)
49
50
51
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand.

    Args:
        inplanes: number of input channels.
        planes: width of the inner 1x1/3x3 convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the middle 3x3 convolution.
        downsample: optional module applied to the block input to build the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    # Output channels are ``planes * expansion``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = L.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = L.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = L.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Stored for introspection; not read inside this block.
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/norm stages, add the shortcut, then ReLU."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
87
88
89
class ResNet(nn.Module):
    """ResNet classifier built from ``networks.layers_WS`` conv/norm layers.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute is the per-stage channel multiplier.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final fully connected layer.

    Note:
        ``self.maxpool`` is constructed with ``return_indices=True`` and
        therefore returns an ``(output, indices)`` pair. The module is kept
        that way (presumably so wrapping encoder code can reuse the pooling
        indices — confirm against callers); ``forward`` unpacks it.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Running channel count, advanced by ``_make_layer`` stage by stage.
        self.inplanes = 64
        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = L.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True,
        )
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the spatial size
        or channel count changes, a 1x1 conv + norm projection shortcut.
        ``self.inplanes`` is advanced to the stage's output width so the
        next stage is wired correctly.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Project the shortcut to match the residual branch's shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                L.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # MaxPool2d(return_indices=True) yields (output, indices); passing the
        # tuple into layer1 would fail, so keep only the pooled tensor. The
        # isinstance check also tolerates a maxpool swapped for one without
        # indices.
        pooled = self.maxpool(x)
        x = pooled[0] if isinstance(pooled, tuple) else pooled

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse (N, C, 1, 1) to (N, C); flatten also copes with
        # non-contiguous inputs, unlike view.
        x = x.flatten(1)
        return self.fc(x)
138
139
140
def l_resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 (``Bottleneck`` blocks, ``[3, 4, 6, 3]`` stages).

    Args:
        pretrained: accepted for compatibility with torchvision-style
            constructors, but no weights are loaded here. When it is truthy
            a ``UserWarning`` is emitted. Load a checkpoint separately with
            ``model.load_state_dict(...)``.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        A freshly constructed ``ResNet``. Its weights are whatever the layer
        constructors initialise.
    """
    import warnings

    if pretrained:
        # Previously the flag was silently ignored while the docstring
        # promised ImageNet weights; make that visible to the caller.
        warnings.warn(
            "l_resnet50(pretrained=True) does not load pretrained weights; "
            "load a checkpoint with model.load_state_dict(...) instead.",
            stacklevel=2,
        )
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
147
148