Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/models/wrn.py
781 views
1
from ...layers import *
2
from ...torch_core import *
3
4
__all__ = ['BasicBlock', 'WideResNet', 'wrn_22']
5
6
def _bn(ni, init_zero=False):
7
"Batchnorm layer with 0 initialization"
8
m = nn.BatchNorm2d(ni)
9
m.weight.data.fill_(0 if init_zero else 1)
10
m.bias.data.zero_()
11
return m
12
13
def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    "Pre-activation stack: BN (optionally zero-initialized) -> ReLU -> conv `ni`->`nf`."
    layers = [_bn(ni, init_zero=init_zero),
              nn.ReLU(inplace=True),
              conv2d(ni, nf, ks, stride)]
    return nn.Sequential(*layers)
class BasicBlock(Module):
    "Pre-activation residual block for a wide ResNet."
    def __init__(self, ni, nf, stride, drop_p=0.0):
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        # Dropout between the two convs only when a nonzero probability is given.
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        # 1x1 conv aligns channels/stride on the shortcut; identity (`noop`) otherwise.
        self.shortcut = conv2d(ni, nf, 1, stride) if ni != nf else noop

    def forward(self, x):
        pre = F.relu(self.bn(x), inplace=True)
        residual = self.shortcut(pre)
        out = self.conv1(pre)
        if self.drop: out = self.drop(out)
        out = self.conv2(out) * 0.2  # scale the branch down before the residual add
        return out.add_(residual)
def _make_group(N, ni, nf, block, stride, drop_p):
35
return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]
36
37
class WideResNet(Module):
    "Wide ResNet: `num_groups` groups of `N` blocks each, widened by factor `k`."
    def __init__(self, num_groups:int, N:int, num_classes:int, k:int=1, drop_p:float=0.0,
                 start_nf:int=16, n_in_channels:int=3):
        # Channel widths: stem width, then start_nf*k doubling once per group.
        widths = [start_nf] + [start_nf * (2**i) * k for i in range(num_groups)]

        layers = [conv2d(n_in_channels, widths[0], 3, 1)]  # stem conv
        for i in range(num_groups):
            group_stride = 1 if i == 0 else 2  # downsample at every group after the first
            layers += _make_group(N, widths[i], widths[i+1], BasicBlock, group_stride, drop_p)

        # Head: final BN/ReLU, global average pool, then the classifier.
        layers += [nn.BatchNorm2d(widths[-1]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                   Flatten(), nn.Linear(widths[-1], num_classes)]
        self.features = nn.Sequential(*layers)

    def forward(self, x): return self.features(x)
def wrn_22():
    "Wide ResNet-22: 3 groups of 3 blocks, widening factor 6, 10 output classes."
    config = dict(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)
    return WideResNet(**config)