Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/models/cadene_models.py
781 views
1
#These models are downloaded via the repo https://github.com/Cadene/pretrained-models.pytorch
2
#See licence here: https://github.com/Cadene/pretrained-models.pytorch/blob/master/LICENSE.txt
3
from torch import nn
4
from ..learner import model_meta
5
from ...core import *
6
7
# Optional dependency: every factory in this module looks its architecture up on
# the `pretrainedmodels` package, so resolve it once at import time.
pretrainedmodels = try_import('pretrainedmodels')
# A falsy result from `try_import` is treated as "package unavailable"; abort the
# import of this module with installation instructions.
if not pretrainedmodels: raise Exception('Error: `pretrainedmodels` is needed. `pip install pretrainedmodels`')
10
11
# Public API of this module: one factory function per supported architecture.
__all__ = [
    'inceptionv4',
    'inceptionresnetv2',
    'nasnetamobile',
    'dpn92',
    'xception_cadene',
    'se_resnet50',
    'se_resnet101',
    'se_resnext50_32x4d',
    'senet154',
    'pnasnet5large',
    'se_resnext101_32x4d',
]
13
14
def get_model(model_name:str, pretrained:bool, seq:bool=False, pname:str='imagenet', **kwargs):
    """Instantiate the `pretrainedmodels` architecture called `model_name`.

    The constructor is looked up by name on the `pretrainedmodels` module and is
    called with `pretrained=pname` when `pretrained` is truthy, otherwise with
    `pretrained=None`. Any extra `kwargs` are forwarded to that constructor.

    When `seq` is true the model's direct children are repacked into an
    `nn.Sequential`; otherwise the constructed model is returned unchanged.
    """
    weights = pname if pretrained else None
    constructor = getattr(pretrainedmodels, model_name)
    net = constructor(pretrained=weights, **kwargs)
    if seq: return nn.Sequential(*net.children())
    return net
18
19
def inceptionv4(pretrained:bool=False):
    """Build `inceptionv4` via `get_model` and flatten it into one `nn.Sequential`.

    The first child of the model is unpacked element by element, and the
    remaining children are appended after it as they are.
    """
    children = list(get_model('inceptionv4', pretrained).children())
    first, rest = children[0], children[1:]
    return nn.Sequential(*first, *rest)

# Register the architecture's metadata. NOTE(review): 'cut'/'split' are presumably
# read by the learner code that owns `model_meta` -- confirm against `..learner`.
model_meta[inceptionv4] = {
    'cut': -2,
    'split': lambda m: (m[0][11], m[1]),
}
24
25
def nasnetamobile(pretrained:bool=False):
    """Build `nasnetamobile` (requested with `num_classes=1000`) wrapped in `nn.Sequential`.

    The model's `logits` attribute is replaced by `noop` before the whole model
    is placed, as a single element, inside an `nn.Sequential`.
    """
    net = get_model('nasnetamobile', pretrained, num_classes=1000)
    setattr(net, 'logits', noop)
    return nn.Sequential(net)

# 'cut' is `noop` here; 'split' returns the child at index 8 of `m[0][0]`
# together with `m[1]`.
model_meta[nasnetamobile] = {
    'cut': noop,
    'split': lambda m: (list(m[0][0].children())[8], m[1]),
}
30
31
def pnasnet5large(pretrained:bool=False):
    """Build `pnasnet5large` (requested with `num_classes=1000`) wrapped in `nn.Sequential`.

    The model's `logits` attribute is replaced by `noop`, and the model is then
    returned as the only element of an `nn.Sequential`.
    """
    backbone = get_model('pnasnet5large', pretrained, num_classes=1000)
    backbone.logits = noop
    wrapped = nn.Sequential(backbone)
    return wrapped

# 'cut' is `noop` here; 'split' returns the child at index 8 of `m[0][0]`
# together with `m[1]`.
model_meta[pnasnet5large] = {
    'cut': noop,
    'split': lambda m: (list(m[0][0].children())[8], m[1]),
}
36
37
def inceptionresnetv2(pretrained:bool=False):
    "Return `inceptionresnetv2` from `get_model`, requested in sequential form (`seq=True`)."
    return get_model('inceptionresnetv2', pretrained, seq=True)
38
def dpn92(pretrained:bool=False):
    "Return `dpn92` from `get_model` in sequential form; the pretrained setting name is 'imagenet+5k'."
    return get_model('dpn92', pretrained, pname='imagenet+5k', seq=True)
39
def xception_cadene(pretrained:bool=False):
    """Return the `xception` architecture from `get_model`, in sequential form.

    The model is requested under the name 'xception'; only this factory's own
    name carries the `_cadene` suffix. `pretrained` is passed to `get_model`
    unchanged.
    """
    return get_model('xception', pretrained, seq=True)
40
def se_resnet50(pretrained:bool=False):
    "Return `se_resnet50` exactly as built by `get_model` (not repacked into a Sequential)."
    return get_model('se_resnet50', pretrained)
41
def se_resnet101(pretrained:bool=False):
    "Return `se_resnet101` exactly as built by `get_model` (not repacked into a Sequential)."
    return get_model('se_resnet101', pretrained)
42
def se_resnext50_32x4d(pretrained:bool=False):
    "Return `se_resnext50_32x4d` exactly as built by `get_model` (not repacked into a Sequential)."
    return get_model('se_resnext50_32x4d', pretrained)
43
def se_resnext101_32x4d(pretrained:bool=False):
    "Return `se_resnext101_32x4d` exactly as built by `get_model` (not repacked into a Sequential)."
    return get_model('se_resnext101_32x4d', pretrained)
44
def senet154(pretrained:bool=False):
    "Return `senet154` exactly as built by `get_model` (not repacked into a Sequential)."
    return get_model('senet154', pretrained)
45
46
# Per-architecture metadata, keyed by factory function.
# NOTE(review): 'cut' and 'split' are presumably consumed by the learner code that
# owns `model_meta` -- confirm their exact meaning against `..learner`.
model_meta[inceptionresnetv2] = {
    'cut': -2,
    'split': lambda m: (m[0][9], m[1]),
}
model_meta[dpn92] = {
    'cut': -1,
    'split': lambda m: (m[0][0][16], m[1]),
}
model_meta[xception_cadene] = {
    'cut': -1,
    'split': lambda m: (m[0][11], m[1]),
}
model_meta[senet154] = {
    'cut': -3,
    'split': lambda m: (m[0][3], m[1]),
}

# The four se_resnet / se_resnext factories all point at this one dict object.
_se_resnet_meta = {'cut': -2, 'split': lambda m: (m[0][3], m[1])}
for _arch in (se_resnet50, se_resnet101, se_resnext50_32x4d, se_resnext101_32x4d):
    model_meta[_arch] = _se_resnet_meta
del _arch  # do not leave the loop variable behind as a module attribute
55
56
# TODO: add "resnext101_32x4d" "resnext101_64x4d" after serialization issue is fixed:
57
# https://github.com/Cadene/pretrained-models.pytorch/pull/128
58
59