Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FBAMatting/networks/resnet_bn.py
3119 views
1
import math
2
3
import torch.nn as nn
4
from torch.nn import BatchNorm2d
5
6
# Public API for ``from <module> import *``: the backbone class and its
# ResNet-50 factory function.
__all__ = ["ResNet", "l_resnet50"]
7
8
9
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride; with ``stride=1`` the 1-pixel padding
            keeps the spatial size unchanged.

    Returns:
        nn.Conv2d: the configured convolution layer.
    """
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": False,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_options)
14
15
16
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + BatchNorm stages and a shortcut.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions (``expansion`` is 1,
            so this is also the block's output width).
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the
            shortcut; when ``None`` the input is added back unchanged.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Submodule registration order fixes the state_dict key order and the
        # order in which ResNet's init loop visits layers; keep it stable.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run conv-BN-ReLU, conv-BN, add the shortcut, then a final ReLU."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
46
47
48
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand, plus shortcut.

    The block outputs ``planes * expansion`` channels.

    Args:
        inplanes: channels of the incoming tensor.
        planes: width of the inner 1x1 and 3x3 convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the
            shortcut. It is needed whenever the main path changes the spatial
            size (``stride != 1``) or the channel count
            (``inplanes != planes * expansion``); otherwise the in-place
            residual addition sees mismatched shapes.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Derive the expanded width from ``expansion`` rather than a literal 4
        # so it always agrees with the ``planes * block.expansion`` width used
        # for shortcut projections and the next block's ``inplanes``.
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False,
        )
        # NOTE(review): only bn2 overrides momentum (0.01); the other BN
        # layers use PyTorch's default. This only affects running-stat updates
        # during training; kept unchanged — confirm whether it is intentional.
        self.bn2 = BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the 1x1/3x3/1x1 path, add the shortcut, then a final ReLU."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return self.relu(out)
86
87
88
class ResNet(nn.Module):
    """ResNet with a three-convolution stem, four residual stages and a linear head.

    Stem: three 3x3 conv + BN + ReLU layers (the first with stride 2, output
    widths 64, 64, 128), followed by a stride-2 max pool. Four stages built
    from ``block`` follow with widths 64, 128, 256 and 512 (times
    ``block.expansion``). Stages 2-4 use stride 2.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            It must expose an ``expansion`` class attribute and accept
            ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks per stage.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Channel count entering layer1; must equal the stem's conv3 output.
        self.inplanes = 128

        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        # return_indices=True makes the pool return (output, indices); forward()
        # below discards the indices. NOTE(review): presumably kept so code that
        # reuses these layers can unpool later — confirm with callers.
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True,
        )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Adaptive pooling reduces any spatial size to 1x1, so ``fc`` works for
        # any input resolution. A fixed 7x7 kernel only fits when layer4
        # outputs exactly 7x7, and in that case both give the same result.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Fan-out He/Kaiming-normal init for convolutions:
        # std = sqrt(2 / (k_h * k_w * out_channels)).
        # BatchNorm starts as identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride``. It also gets a 1x1 conv + BN
        shortcut projection when the stride or the channel count changes.
        Afterwards ``self.inplanes`` is set to the stage's output width, so
        the next stage is wired correctly.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run stem, max pool, the four stages, global pooling and ``fc``."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x, _ = self.maxpool(x)  # pooling indices are not used here

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # flatten(1) also handles an empty batch, where view(0, -1) would fail.
        x = x.flatten(1)
        return self.fc(x)
157
158
159
def l_resnet50(**kwargs):
    """Construct a ResNet-50 layout: ``Bottleneck`` blocks in stages of [3, 4, 6, 3].

    No pretrained weights are loaded; parameters come from ``ResNet``'s own
    initialisation.

    Args:
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        ResNet: the constructed model.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
166
167