from torch import nn
from ..learner import model_meta
from ...core import *
# Every constructor below looks its architecture up on the `pretrainedmodels` package,
# so the module refuses to load without it.
# NOTE(review): presumably `try_import` returns a falsy value when the import fails --
# confirm against its definition in `...core`.
pretrainedmodels = try_import('pretrainedmodels')
if not pretrainedmodels:
    # ImportError is the specific built-in for a missing dependency; it is still an
    # `Exception` subclass, so existing `except Exception` handlers keep working.
    raise ImportError('Error: `pretrainedmodels` is needed. `pip install pretrainedmodels`')
# Names exported by a star-import of this module: one constructor per architecture.
__all__ = [
    'inceptionv4',
    'inceptionresnetv2',
    'nasnetamobile',
    'dpn92',
    'xception_cadene',
    'se_resnet50',
    'se_resnet101',
    'se_resnext50_32x4d',
    'senet154',
    'pnasnet5large',
    'se_resnext101_32x4d',
]
def get_model(model_name:str, pretrained:bool, seq:bool=False, pname:str='imagenet', **kwargs):
    """Instantiate the `pretrainedmodels` architecture called `model_name`.

    Args:
        model_name: attribute name of the constructor on the `pretrainedmodels` package.
        pretrained: when truthy, `pname` is passed as the constructor's `pretrained`
            argument; otherwise `None` is passed.
        seq: when truthy, return the model's direct children wrapped in an `nn.Sequential`
            instead of the model object itself.
        pname: value forwarded as `pretrained=` when `pretrained` is truthy.
        **kwargs: forwarded unchanged to the constructor.

    Returns:
        The constructed model, or an `nn.Sequential` of its children when `seq` is set.
    """
    weights = pname if pretrained else None
    constructor = getattr(pretrainedmodels, model_name)
    net = constructor(pretrained=weights, **kwargs)
    if not seq:
        return net
    return nn.Sequential(*net.children())
def inceptionv4(pretrained:bool=False):
    """Build `inceptionv4` via `get_model` and return it as a single `nn.Sequential`.

    The model's first child is unpacked, so its own children become top-level entries
    of the result; the remaining children are appended as they are.
    """
    children = list(get_model('inceptionv4', pretrained).children())
    first, others = children[0], children[1:]
    return nn.Sequential(*first, *others)
model_meta[inceptionv4] = {'cut': -2, 'split': lambda m: (m[0][11], m[1])}
def nasnetamobile(pretrained:bool=False):
    """Build `nasnetamobile` (requested with `num_classes=1000`) wrapped in an `nn.Sequential`.

    The model's `logits` attribute is replaced with `noop` before wrapping.
    NOTE(review): presumably `noop` is an identity function, so the classifier step is
    skipped -- confirm against `...core`.
    """
    net = get_model('nasnetamobile', pretrained, num_classes=1000)
    net.logits = noop
    wrapped = nn.Sequential(net)
    return wrapped
model_meta[nasnetamobile] = {'cut': noop, 'split': lambda m: (list(m[0][0].children())[8], m[1])}
def pnasnet5large(pretrained:bool=False):
    """Build `pnasnet5large` (requested with `num_classes=1000`) wrapped in an `nn.Sequential`.

    The model's `logits` attribute is replaced with `noop` before wrapping.
    NOTE(review): presumably `noop` is an identity function, so the classifier step is
    skipped -- confirm against `...core`.
    """
    net = get_model('pnasnet5large', pretrained, num_classes=1000)
    net.logits = noop
    wrapped = nn.Sequential(net)
    return wrapped
model_meta[pnasnet5large] = {'cut': noop, 'split': lambda m: (list(m[0][0].children())[8], m[1])}
def inceptionresnetv2(pretrained:bool=False): return get_model('inceptionresnetv2', pretrained, seq=True)
def dpn92(pretrained:bool=False): return get_model('dpn92', pretrained, pname='imagenet+5k', seq=True)
def xception_cadene(pretrained=False): return get_model('xception', pretrained, seq=True)
def se_resnet50(pretrained:bool=False): return get_model('se_resnet50', pretrained)
def se_resnet101(pretrained:bool=False): return get_model('se_resnet101', pretrained)
def se_resnext50_32x4d(pretrained:bool=False): return get_model('se_resnext50_32x4d', pretrained)
def se_resnext101_32x4d(pretrained:bool=False): return get_model('se_resnext101_32x4d', pretrained)
def senet154(pretrained:bool=False): return get_model('senet154', pretrained)
# Per-architecture metadata: an integer `cut` and a `split` callable returning two
# sub-modules of its argument. NOTE(review): presumably `cut` is where `..learner` trims
# the model and `split` defines its layer groups -- confirm there.
model_meta[inceptionresnetv2] = dict(cut=-2, split=lambda m: (m[0][9], m[1]))
model_meta[dpn92] = dict(cut=-1, split=lambda m: (m[0][0][16], m[1]))
model_meta[xception_cadene] = dict(cut=-1, split=lambda m: (m[0][11], m[1]))
model_meta[senet154] = dict(cut=-3, split=lambda m: (m[0][3], m[1]))
# A single metadata dict, registered as the same object for all four SE-ResNet /
# SE-ResNeXt constructors.
_se_resnet_meta = dict(cut=-2, split=lambda m: (m[0][3], m[1]))
for _se_arch in (se_resnet50, se_resnet101, se_resnext50_32x4d, se_resnext101_32x4d):
    model_meta[_se_arch] = _se_resnet_meta
del _se_arch  # do not leave the loop variable behind in the module namespace