import torch.nn as nn
import torch,math,sys
import torch.utils.model_zoo as model_zoo
from functools import partial
from ...torch_core import Module
# Public names of this module: the XResNet class and the per-depth factory
# functions that are attached to the module at the bottom of the file.
__all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152']
# A single in-place ReLU instance; the same object is appended to every
# conv_layer that requests an activation and is also called by ResBlock.forward.
act_fn = nn.ReLU(inplace=True)
class Flatten(Module):
    """Collapse every dimension after the first one into a single dimension."""

    def forward(self, x):
        # Keep the size of dim 0 and let `view` infer the merged remainder.
        leading = x.size(0)
        return x.view(leading, -1)
def init_cnn(m):
    """Re-initialise `m` and, recursively, every module below it.

    For each visited module: a `bias` attribute that is not None is filled
    with zeros, and the weight of an `nn.Conv2d` or `nn.Linear` is redrawn
    with `nn.init.kaiming_normal_`. The module itself is handled before its
    children, and children are visited in `m.children()` order.
    """
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)

    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)

    # Plain recursion (rather than m.modules()) so a sub-module reachable
    # through several parents is visited once per parent.
    for child in m.children():
        init_cnn(child)
def conv(ni, nf, ks=3, stride=1, bias=False):
    """Create an `nn.Conv2d` from `ni` to `nf` channels.

    The kernel size is `ks`, the stride is `stride`, and the padding is
    `ks // 2`. The convolution has no bias unless `bias` is true.
    """
    half_kernel = ks // 2
    return nn.Conv2d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
def noop(x):
    """Identity function: return the argument unchanged."""
    return x
def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    """Build a conv -> batch-norm (-> activation) `nn.Sequential`.

    The batch-norm weight is set to 0 when `zero_bn` is true and to 1
    otherwise. The module-level `act_fn` is appended only when `act` is true.
    """
    norm = nn.BatchNorm2d(nf)
    gamma = 0. if zero_bn else 1.
    nn.init.constant_(norm.weight, gamma)

    modules = [conv(ni, nf, ks, stride=stride), norm]
    if act:
        modules.append(act_fn)
    return nn.Sequential(*modules)
class ResBlock(Module):
    """Residual block: a conv path added to a shortcut path, then `act_fn`.

    With `expansion == 1` the conv path is two 3x3 conv layers; otherwise it
    is a 1x1 -> 3x3 -> 1x1 stack. In both cases the last layer has a
    zero-initialised batch-norm weight and no activation. `ni` and `nh` are
    given before expansion; the block maps `ni*expansion` channels to
    `nh*expansion` channels.
    """

    def __init__(self, expansion, ni, nh, stride=1):
        # NOTE(review): there is no explicit super().__init__() call here;
        # presumably the imported `Module` base takes care of it - confirm.
        nf, ni = nh*expansion, ni*expansion
        if expansion == 1:
            layers = [
                conv_layer(ni, nh, 3, stride=stride),
                conv_layer(nh, nf, 3, zero_bn=True, act=False),
            ]
        else:
            layers = [
                conv_layer(ni, nh, 1),
                conv_layer(nh, nh, 3, stride=stride),
                conv_layer(nh, nf, 1, zero_bn=True, act=False),
            ]
        self.convs = nn.Sequential(*layers)

        # Shortcut: a 1x1 conv layer only when the channel count changes.
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False)

        # Shortcut: average-pool only when the conv path is strided.
        if stride == 1:
            self.pool = noop
        else:
            self.pool = nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        main = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return act_fn(main + shortcut)
def filt_sz(recep):
    """Return 2 raised to floor(log2(0.75 * `recep`)), capped at 64."""
    exponent = math.floor(math.log2(recep*0.75))
    return min(64, 2**exponent)
class XResNet(nn.Sequential):
    """Sequential network: conv stem, max-pool, residual stages and a linear head.

    `expansion` is forwarded to every `ResBlock`, `layers` gives the number
    of blocks in each stage, `c_in` is the number of input channels and
    `c_out` the number of outputs of the final linear layer. After assembly
    the whole network is passed through `init_cnn`.
    """

    def __init__(self, expansion, layers, c_in=3, c_out=1000):
        # Stem: three 3x3 conv layers; only the first one is strided.
        widths = [c_in, 32, 32, 64]
        stem = [
            conv_layer(n_in, n_out, stride=2 if pos == 0 else 1)
            for pos, (n_in, n_out) in enumerate(zip(widths, widths[1:]))
        ]

        # Stage widths before expansion; the first entry is the stem width
        # divided by `expansion` because ResBlock multiplies it back.
        block_szs = [64//expansion, 64, 128, 256, 512]
        blocks = []
        for idx, n_blocks in enumerate(layers):
            stage_stride = 1 if idx == 0 else 2
            blocks.append(
                self._make_layer(expansion, block_szs[idx], block_szs[idx+1],
                                 n_blocks, stage_stride)
            )

        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(block_szs[-1]*expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride):
        """Stack `blocks` ResBlocks; only the first gets `ni` and `stride`."""
        stage = []
        for idx in range(blocks):
            is_first = idx == 0
            stage.append(
                ResBlock(expansion, ni if is_first else nf, nf,
                         stride if is_first else 1)
            )
        return nn.Sequential(*stage)
def xresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    """Build an `XResNet` and optionally load pretrained weights into it.

    Args:
        expansion: channel expansion factor forwarded to `XResNet`.
        n_layers: number of residual blocks per stage, forwarded to `XResNet`.
        name: key used to look up the weights URL when `pretrained` is true.
        pretrained: when true, download the state dict for `name` via
            `model_zoo.load_url` and load it into the model.
        **kwargs: extra keyword arguments forwarded to `XResNet`.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if `pretrained` is true but this module has no
            `model_urls` table to look the weights up in.
        KeyError: if `model_urls` exists but has no entry for `name`.
    """
    urls = None
    if pretrained:
        # `model_urls` is not defined in this module; without this check the
        # lookup below fails with a bare NameError, and only after the whole
        # network has been built. Fail early with an actionable message.
        urls = globals().get('model_urls')
        if urls is None:
            raise NotImplementedError(
                f"No pretrained weights are available for '{name}': "
                "this module defines no `model_urls` table."
            )

    model = XResNet(expansion, n_layers, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(urls[name]))
    return model
# Handle on this module object, so the factories can be attached as attributes.
me = sys.modules[__name__]

# One row per architecture: (depth, expansion, residual blocks per stage).
# Each row becomes a module-level callable `xresnet<depth>` that pre-binds
# its configuration onto `xresnet`.
for n, e, l in (
    (18, 1, [2, 2, 2, 2]),
    (34, 1, [3, 4, 6, 3]),
    (50, 4, [3, 4, 6, 3]),
    (101, 4, [3, 4, 23, 3]),
    (152, 4, [3, 8, 36, 3]),
):
    name = f'xresnet{n}'
    setattr(me, name, partial(xresnet, expansion=e, n_layers=l, name=name))