GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/models/xresnet.py

import torch.nn as nn
import torch,math,sys
import torch.utils.model_zoo as model_zoo
from functools import partial
from ...torch_core import Module

__all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152']

# or: ELU+init (a=0.54; gain=1.55)
act_fn = nn.ReLU(inplace=True)

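# Flatten collapses every dimension after the batch dimension into one vector,
# so the pooled feature map can feed the final nn.Linear head.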
class Flatten(Module):
    def forward(self, x): return x.view(x.size(0), -1)

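# init_cnn recursively walks the module tree, applying Kaiming-normal
# initialisation to Conv2d/Linear weights and zeroing any biases.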
def init_cnn(m):
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)

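# conv: plain 2d convolution with "same"-style padding (ks//2) and no bias,
# since a BatchNorm layer normally follows.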
def conv(ni, nf, ks=3, stride=1, bias=False):
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)

def noop(x): return x

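# conv_layer: conv -> BatchNorm -> (optional) ReLU. With zero_bn=True the
# BatchNorm weight is initialised to zero, so a residual branch ending in such
# a layer contributes nothing at first and the block starts out as an identity.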
def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    bn = nn.BatchNorm2d(nf)
    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
    layers = [conv(ni, nf, ks, stride=stride), bn]
    if act: layers.append(act_fn)
    return nn.Sequential(*layers)

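# ResBlock: residual block used by XResNet. expansion==1 gives the basic
# two-conv block (xresnet18/34); expansion==4 gives the 1x1-3x3-1x1 bottleneck
# (xresnet50 and up). The last conv_layer uses zero_bn=True, and when the
# spatial size or channel count changes the identity path is downsampled with
# AvgPool2d plus a 1x1 conv rather than a strided convolution.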
class ResBlock(Module):
    def __init__(self, expansion, ni, nh, stride=1):
        nf,ni = nh*expansion,ni*expansion
        layers = [conv_layer(ni, nh, 3, stride=stride),
                  conv_layer(nh, nf, 3, zero_bn=True, act=False)
        ] if expansion == 1 else [
                  conv_layer(ni, nh, 1),
                  conv_layer(nh, nh, 3, stride=stride),
                  conv_layer(nh, nf, 1, zero_bn=True, act=False)
        ]
        self.convs = nn.Sequential(*layers)
        # TODO: check whether act=True works better
        self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False)
        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x): return act_fn(self.convs(x) + self.idconv(self.pool(x)))

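# filt_sz is only referenced by the commented-out alternative stem inside
# XResNet.__init__ in this module.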
def filt_sz(recep): return min(64, 2**math.floor(math.log2(recep*0.75)))

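# XResNet: a ResNet whose stem is three 3x3 conv_layers (c_in -> 32 -> 32 -> 64)
# instead of a single 7x7 convolution, followed by max-pooling, the residual
# stages, global average pooling and a linear classifier. Weights are
# initialised with init_cnn.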
class XResNet(nn.Sequential):
    def __init__(self, expansion, layers, c_in=3, c_out=1000):
        stem = []
        sizes = [c_in,32,32,64]
        for i in range(3):
            stem.append(conv_layer(sizes[i], sizes[i+1], stride=2 if i==0 else 1))
            #nf = filt_sz(c_in*9)
            #stem.append(conv_layer(c_in, nf, stride=2 if i==1 else 1))
            #c_in = nf

        block_szs = [64//expansion,64,128,256,512]
        blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i==0 else 2)
                  for i,l in enumerate(layers)]
        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(block_szs[-1]*expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride):
        return nn.Sequential(
            *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1)
              for i in range(blocks)])

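# xresnet: builds an XResNet with the given expansion and per-stage block
# counts. Note that model_urls is not defined in this module, so
# pretrained=True will raise a NameError unless it is supplied elsewhere.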
def xresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    model = XResNet(expansion, n_layers, **kwargs)
    if pretrained: model.load_state_dict(model_zoo.load_url(model_urls[name]))
    return model

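# Generate the concrete constructors listed in __all__ (xresnet18/34/50/101/152)
# by binding the expansion factor and layer counts onto xresnet with
# functools.partial and attaching each one to this module.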
me = sys.modules[__name__]
for n,e,l in [
    [ 18 , 1, [2,2,2 ,2] ],
    [ 34 , 1, [3,4,6 ,3] ],
    [ 50 , 4, [3,4,6 ,3] ],
    [ 101, 4, [3,4,23,3] ],
    [ 152, 4, [3,8,36,3] ],
]:
    name = f'xresnet{n}'
    setattr(me, name, partial(xresnet, expansion=e, n_layers=l, name=name))
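
# Illustrative usage sketch (not part of the original file). Because of the
# relative torch_core import, this module is meant to be imported as part of
# its package; the import path below is inferred from the file location and
# the c_out keyword is the classifier size defined in XResNet.__init__:
#
#   from fastai.vision.models.xresnet import xresnet50
#   import torch
#   model = xresnet50(c_out=10)           # 10-class head instead of 1000
#   x = torch.randn(2, 3, 224, 224)       # batch of two RGB images
#   print(model(x).shape)                 # -> torch.Size([2, 10])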