from ...layers import *
from ...torch_core import *
__all__ = ['BasicBlock', 'WideResNet', 'wrn_22']
def _bn(ni, init_zero=False):
    "Batchnorm layer, with the weights optionally initialized to zero."
    m = nn.BatchNorm2d(ni)
    m.weight.data.fill_(0 if init_zero else 1)
    m.bias.data.zero_()
    return m
def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    "BN-ReLU-conv sequence in preactivation order; `init_zero` zeroes the BN weights."
    bn_initzero = _bn(ni, init_zero=init_zero)
    return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv2d(ni, nf, ks, stride))
class BasicBlock(Module):
    "Block to form a wide ResNet."
    def __init__(self, ni, nf, stride, drop_p=0.0):
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        # 1x1 conv on the shortcut when the channel count changes, identity otherwise
        self.shortcut = conv2d(ni, nf, 1, stride) if ni != nf else noop
    def forward(self, x):
        x2 = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x2)
        x = self.conv1(x2)
        if self.drop: x = self.drop(x)
        # Scale the residual branch down before adding it back to the shortcut.
        x = self.conv2(x) * 0.2
        return x.add_(r)
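
# Shape sanity check (a sketch, not part of the module; assumes 32x32 feature maps):
#   blk = BasicBlock(16, 32, stride=2)
#   blk(torch.randn(1, 16, 32, 32)).shape  # -> torch.Size([1, 32, 16, 16])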
def _make_group(N, ni, nf, block, stride, drop_p):
    "Make a group of `N` blocks; only the first block changes the channel count and stride."
    return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]
class WideResNet(Module):
    "Wide ResNet with `num_groups` and a width of `k`."
    def __init__(self, num_groups:int, N:int, num_classes:int, k:int=1, drop_p:float=0.0, start_nf:int=16, n_in_channels:int=3):
        n_channels = [start_nf]
        for i in range(num_groups): n_channels.append(start_nf*(2**i)*k)

        layers = [conv2d(n_in_channels, n_channels[0], 3, 1)]
        for i in range(num_groups):
            layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock, (1 if i==0 else 2), drop_p)

        layers += [nn.BatchNorm2d(n_channels[num_groups]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                   Flatten(), nn.Linear(n_channels[num_groups], num_classes)]
        self.features = nn.Sequential(*layers)

    def forward(self, x): return self.features(x)
def wrn_22():
    "Wide ResNet with 22 layers."
    return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)
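
# Example usage (a minimal sketch, assuming CIFAR-10-style 3x32x32 inputs; not part of the API):
#   model = wrn_22()
#   model(torch.randn(2, 3, 32, 32)).shape  # -> torch.Size([2, 10])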