GitHub Repository: hackassin/learnopencv
Path: blob/master/FaceMaskOverlay/tools/train.py
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Tianheng Cheng([email protected])
# ------------------------------------------------------------------------------

import os
import pprint
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import lib.models as models
from lib.config import config, update_config
from lib.datasets import get_dataset
from lib.core import function
from lib.utils import utils


def parse_args():

    parser = argparse.ArgumentParser(description='Train Face Alignment')

    parser.add_argument('--cfg', help='experiment configuration filename',
                        required=True, type=str)

    args = parser.parse_args()
    update_config(config, args)
    return args

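# Usage sketch: the script takes a single --cfg argument pointing at an
# experiment YAML; the exact config filename shipped with this repository is
# not shown here, so the path below is a hypothetical placeholder.
#
#     python tools/train.py --cfg <path/to/experiment_config.yaml>
#
# update_config() merges the YAML referenced by --cfg into the global `config`
# object, and utils.create_logger() then derives the output and TensorBoard
# log directories used throughout main().
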
def main():

    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    model = models.get_face_alignment_net(config)

    # TensorBoard writer and global step counters
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # loss: mean squared error over the predicted heatmaps
    # (reduction='mean' replaces the deprecated size_average=True)
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = utils.get_optimizer(config, model)

    # track the best normalized mean error (NME) seen so far
    best_nme = 100
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'latest.pth')
        if os.path.islink(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_nme = checkpoint['best_nme']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})"
                  .format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found")

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )

    dataset_type = get_dataset(config)

    train_loader = DataLoader(
        dataset=dataset_type(config, is_train=True),
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    val_loader = DataLoader(
        dataset=dataset_type(config, is_train=False),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        # advance the LR schedule at the start of each epoch
        # (PyTorch >= 1.1 recommends stepping after the optimizer updates)
        lr_scheduler.step()

        # train for one epoch
        function.train(config, train_loader, model, criterion,
                       optimizer, epoch, writer_dict)

        # evaluate
        nme, predictions = function.validate(config, val_loader, model,
                                             criterion, epoch, writer_dict)

        is_best = nme < best_nme
        best_nme = min(nme, best_nme)

        # save model, optimizer state and best NME for this epoch
        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        print("best:", is_best)
        utils.save_checkpoint(
            {"state_dict": model,
             "epoch": epoch + 1,
             "best_nme": best_nme,
             "optimizer": optimizer.state_dict(),
             }, predictions, is_best, final_output_dir,
            'checkpoint_{}.pth'.format(epoch))

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()


if __name__ == '__main__':
    main()