Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/models/xception.py
781 views
1
from ...vision import *
2
3
# Only the `xception` constructor is exported; the other definitions in this
# module are building blocks it uses.
__all__ = ['xception']
4
5
def sep_conv(ni,nf,pad=None,pool=False,act=True):
    """Build a separable-convolution block as an `nn.Sequential`.

    Layer order: optional ReLU, 3x3 depthwise conv (`ni` -> `ni`), 1x1
    pointwise conv (`ni` -> `nf`), BatchNorm, optional 2x2 max pool.

    ni: number of input channels.
    nf: number of output channels.
    pad: accepted for signature compatibility but not used; the depthwise
        conv always uses a padding of 1.
    pool: if true, append a 2x2 max pool that halves the spatial size.
    act: if true, prepend a ReLU (activation comes *before* the convs).
    """
    layers = [nn.ReLU()] if act else []
    layers += [
        # groups=ni gives every input channel its own 3x3 filter (depthwise).
        nn.Conv2d(ni,ni,3,1,1,groups=ni,bias=False),
        # 1x1 conv mixes the channels and changes the width to nf (pointwise).
        nn.Conv2d(ni,nf,1,bias=False),
        nn.BatchNorm2d(nf)
    ]
    # ceil_mode=True pools an odd spatial size n to ceil(n/2), which is the
    # size a 1x1 stride-2 convolution produces. Without it the pooled branch
    # and the strided shortcut in ConvSkip disagree by one pixel on odd
    # feature maps and cannot be added. Even sizes are unaffected.
    if pool: layers.append(nn.MaxPool2d(2, ceil_mode=True))
    return nn.Sequential(*layers)
14
15
def conv(ni,nf,ks=1,stride=1, pad=None, act=True):
    """Build a bias-free `Conv2d` -> `BatchNorm2d` block, optionally ending in ReLU.

    ni: number of input channels.
    nf: number of output channels.
    ks: kernel size.
    stride: convolution stride.
    pad: padding; when None it defaults to `ks//2`.
    act: if true, append a ReLU after the batch norm.
    """
    padding = ks//2 if pad is None else pad
    stem = [nn.Conv2d(ni,nf,ks,stride,padding,bias=False), nn.BatchNorm2d(nf)]
    tail = [nn.ReLU()] if act else []
    return nn.Sequential(*stem, *tail)
23
24
class ConvSkip(Module):
    """Downsampling block that sums a strided shortcut with a separable-conv branch.

    Shortcut: `conv(ni, nf, stride=2, act=False)`, a 1x1 stride-2 conv plus
    batch norm with no activation.
    Main branch: `sep_conv(ni, ni)` followed by `sep_conv(ni, nf, pool=True)`.

    ni: number of input channels.
    nf: number of output channels; defaults to `ni` when None.
    act: whether the first separable conv starts with a ReLU.
    """
    def __init__(self,ni,nf=None,act=True):
        # NOTE(review): no super().__init__() here, as in the original;
        # presumably the `Module` base from the star import takes care of it. Confirm.
        self.ni = ni
        # Resolve the default once and use the resolved value everywhere.
        # Previously the raw `nf` (possibly None) was passed to the layers
        # below, so the default could not actually be used.
        self.nf = ni if nf is None else nf
        self.conv = conv(ni,self.nf,stride=2, act=False)
        self.m = nn.Sequential(
            sep_conv(ni,ni,act=act),
            sep_conv(ni,self.nf,pool=True)
        )

    def forward(self,x): return self.conv(x) + self.m(x)
35
36
def middle_flow(nf):
    "Stack three `sep_conv(nf,nf)` blocks in a `SequentialEx` that ends with a `MergeLayer`."
    blocks = []
    for _ in range(3):
        blocks.append(sep_conv(nf,nf))
    # NOTE(review): presumably MergeLayer adds the block's input back onto the
    # output (residual connection); confirm against its definition.
    blocks.append(MergeLayer())
    return SequentialEx(*blocks)
39
40
def xception(c, k=8, n_middle=8):
    """Preview version of Xception network. Not tested yet - use at own risk. No pretrained model yet.

    c: number of outputs of the final linear layer.
    k: width multiplier; every channel count is a multiple of `k`.
    n_middle: number of `middle_flow` blocks between the entry and exit parts.

    The first conv takes 3 input channels.
    """
    mid_width = k*91
    entry = [
        conv(3, k*4, 3, 2),
        conv(k*4, k*8, 3),
        # No leading ReLU here: the preceding conv block already ends in one.
        ConvSkip(k*8, k*16, act=False),
        ConvSkip(k*16, k*32),
        ConvSkip(k*32, mid_width),
    ]
    middle = [middle_flow(mid_width) for _ in range(n_middle)]
    exit_ = [
        ConvSkip(mid_width, k*128),
        sep_conv(k*128, k*192, act=False),
        sep_conv(k*192, k*256),
        nn.ReLU(),
        # Pool to 1x1 and flatten, leaving k*256 features for the linear head.
        nn.AdaptiveAvgPool2d(1),
        Flatten(),
        nn.Linear(k*256, c),
    ]
    return nn.Sequential(*entry, *middle, *exit_)
60
61
62