Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/resnet.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/metrics/resnet.py
6
7
import math
8
9
import torch.nn as nn
10
11
import utils.ops as ops
12
13
14
class BasicBlock(nn.Module):
    """Residual unit made of two ``ops.conv3x3`` convolutions, each followed by BatchNorm.

    Args:
        inplanes: number of channels in the incoming tensor.
        planes: channel count passed as the output width of both
            ``ops.conv3x3`` calls (``expansion`` is 1, so this is also the
            block's nominal output width).
        stride: passed as the third argument of the first ``ops.conv3x3``.
        downsample: optional module applied to the input to form the shortcut.
            When ``None`` the input itself is added, so its shape must already
            match the main branch output.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Registration order is kept fixed: it determines the state_dict key
        # order and the sequence in which seeded parameters are initialised.
        self.conv1 = ops.conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = ops.conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)
43
44
45
class Bottleneck(nn.Module):
    """Three-convolution residual unit (1x1 -> 3x3 -> 1x1) emitting ``planes * 4`` channels.

    Args:
        inplanes: number of channels in the incoming tensor.
        planes: width of the inner 1x1 and 3x3 convolutions; the final 1x1
            convolution widens this to ``planes * Bottleneck.expansion``.
        stride: stride of the middle 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
            When ``None`` the input itself is added, so it must already have
            ``planes * 4`` channels and the output spatial size.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        width_out = planes * Bottleneck.expansion

        # Registration order is kept fixed: it determines the state_dict key
        # order and the sequence in which seeded parameters are initialised.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the 1x1 -> 3x3 -> 1x1 branch, add the shortcut, then ReLU."""
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)
80
81
class ResNet(nn.Module):
    """ResNet classifier with either a CIFAR-style or an ImageNet-style layout.

    Args:
        dataset: dataset name. Any name starting with ``"CIFAR10"`` (this
            includes ``"CIFAR100"``) selects the three-stage CIFAR layout.
            ``"ImageNet"`` selects the four-stage ImageNet layout. Any other
            name raises ``ValueError``, since no layers exist for it.
        depth: network depth. CIFAR: each of the three stages holds
            ``int((depth - 2) / 6)`` basic blocks or ``int((depth - 2) / 9)``
            bottleneck blocks. ImageNet: one of 18, 34, 50, 101, 152, 200.
        num_classes: output size of the final ``nn.Linear`` layer.
        bottleneck: CIFAR layout only; use ``Bottleneck`` instead of
            ``BasicBlock``. For ImageNet the block type is chosen by ``depth``.

    Raises:
        ValueError: unsupported ``dataset``, or unsupported ImageNet ``depth``.

    NOTE(review): ``fc`` expects the average-pooled map to be 1x1. The CIFAR
    layout pools with an 8x8 window (presumably 32x32 inputs — confirm). The
    ImageNet layout pools with a 7x7 window after an overall stride of 16
    (no max-pool), so e.g. 112-128 px inputs fit; other sizes may not.
    """

    def __init__(self, dataset, depth, num_classes, bottleneck=False):
        super(ResNet, self).__init__()
        self.dataset = dataset
        if self._is_cifar(dataset):
            self._build_cifar(depth, num_classes, bottleneck)
        elif dataset == "ImageNet":
            self._build_imagenet(depth, num_classes)
        else:
            raise ValueError(
                f"unsupported dataset {dataset!r} for ResNet "
                "(expected a name starting with 'CIFAR10', or 'ImageNet')"
            )
        self._init_weights()

    @staticmethod
    def _is_cifar(dataset):
        """Return True when ``dataset`` selects the CIFAR layout (shared by __init__ and forward)."""
        return dataset.startswith("CIFAR10")

    def _build_cifar(self, depth, num_classes, bottleneck):
        """Register the three-stage CIFAR layout (16/32/64 base widths)."""
        self.inplanes = 16
        if bottleneck:
            n = int((depth - 2) / 9)
            block = Bottleneck
        else:
            n = int((depth - 2) / 6)
            block = BasicBlock

        # Registration order matters: it fixes state_dict keys and the order
        # in which _init_weights draws random numbers.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

    def _build_imagenet(self, depth, num_classes):
        """Register the four-stage ImageNet layout (64/128/256/512 base widths)."""
        layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3],
                  101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
        # Explicit check: an `assert` would be stripped under `python -O`.
        if depth not in layers:
            raise ValueError(
                "invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200), "
                f"got {depth!r}"
            )
        blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck,
                  101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
        block, counts = blocks[depth], layers[depth]

        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # No max-pool after the stem in this layout; forward never applies one.
        self.layer1 = self._make_layer(block, 64, counts[0])
        self.layer2 = self._make_layer(block, 128, counts[1], stride=2)
        self.layer3 = self._make_layer(block, 256, counts[2], stride=2)
        self.layer4 = self._make_layer(block, 512, counts[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _init_weights(self):
        """Init convs with N(0, std=sqrt(2 / fan_out)) and BatchNorm with weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual units of type ``block`` into one ``nn.Sequential`` stage.

        Only the first unit receives ``stride``. It also gets a 1x1-conv +
        BatchNorm projection shortcut when the stride or the channel count
        changes. ``self.inplanes`` is advanced to the stage's output width so
        the next stage is wired correctly.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        # range(1, blocks) is empty when blocks <= 1, so a stage always has one unit.
        units.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*units)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, num_classes)``.

        The dispatch uses the same ``_is_cifar`` test as ``__init__``, so the
        forward path always matches the layers that were built.
        """
        x = self.relu(self.bn1(self.conv1(x)))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if not self._is_cifar(self.dataset):
            # Only the ImageNet layout has a fourth stage.
            x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
173
174