Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FaceMaskOverlay/lib/utils/utils.py
3443 views
1
# ------------------------------------------------------------------------------
2
# Copyright (c) Microsoft
3
# Licensed under the MIT License.
4
# Written by Bin Xiao ([email protected])
5
# Modified by Ke Sun ([email protected]), Tianheng Cheng([email protected])
6
# ------------------------------------------------------------------------------
7
8
from __future__ import absolute_import
9
from __future__ import division
10
from __future__ import print_function
11
12
import os
13
import logging
14
import time
15
from pathlib import Path
16
17
import torch
18
import torch.optim as optim
19
20
21
def create_logger(cfg, cfg_name, phase='train'):
    """Set up logging and the output directories for a run.

    Creates ``<cfg.OUTPUT_DIR>/<dataset>/<run name>`` for checkpoints and log
    files. Attaches to the root logger a file handler that writes
    ``<run name>_<YYYY-mm-dd-HH-MM>_<phase>.log`` in that directory, plus a
    console handler. Also creates the TensorBoard directory
    ``<cfg.LOG_DIR>/<dataset>/<model>/<run name>_<YYYY-mm-dd-HH-MM>``.

    Args:
        cfg: config object; reads ``OUTPUT_DIR``, ``LOG_DIR``,
            ``DATASET.DATASET`` and ``MODEL.NAME``.
        cfg_name: path of the experiment config file. Its basename, up to
            the first '.', is used as the run name.
        phase: suffix appended to the log file name (e.g. 'train').

    Returns:
        Tuple ``(logger, final_output_dir, tensorboard_log_dir)``: the root
        logger and the two created directories as strings.
    """
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        # parents=True so a nested OUTPUT_DIR (e.g. 'a/b/output') works too.
        root_output_dir.mkdir(parents=True, exist_ok=True)

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # logging.basicConfig() is a no-op once the root logger has handlers.
    # A second call (e.g. train, then test) would then never create its own
    # log file, so attach the file handler explicitly.
    file_handler = logging.FileHandler(str(final_log_file))
    file_handler.setFormatter(logging.Formatter(head))
    logger.addHandler(file_handler)
    # Add the console handler only once; otherwise every call would add
    # another one and each message would be printed multiple times.
    # type() rather than isinstance(): FileHandler subclasses StreamHandler.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
        (cfg_name + '_' + time_str)
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)
54
55
56
def get_optimizer(cfg, model):
    """Build the optimizer selected by ``cfg.TRAIN.OPTIMIZER`` for ``model``.

    Only parameters with ``requires_grad`` set are passed to the optimizer,
    so frozen parameters are skipped.

    Args:
        cfg: config object. Reads ``cfg.TRAIN.OPTIMIZER`` ('sgd', 'adam' or
            'rmsprop') and ``cfg.TRAIN.LR``. Depending on the optimizer, it
            also reads ``MOMENTUM``, ``WD``, ``NESTEROV``, ``RMSPROP_ALPHA``
            and ``RMSPROP_CENTERED`` from ``cfg.TRAIN``.
        model: module whose parameters are optimized.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``cfg.TRAIN.OPTIMIZER`` is not a supported name.
            Previously this silently returned ``None``, which only failed
            later and obscurely.
    """
    name = cfg.TRAIN.OPTIMIZER
    params = [p for p in model.parameters() if p.requires_grad]

    if name == 'sgd':
        return optim.SGD(
            params,
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            nesterov=cfg.TRAIN.NESTEROV
        )
    if name == 'adam':
        # NOTE: Adam is built with lr only; cfg.TRAIN.WD is not applied here.
        return optim.Adam(
            params,
            lr=cfg.TRAIN.LR
        )
    if name == 'rmsprop':
        return optim.RMSprop(
            params,
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            alpha=cfg.TRAIN.RMSPROP_ALPHA,
            centered=cfg.TRAIN.RMSPROP_CENTERED
        )

    raise ValueError(
        "Unsupported optimizer {!r}; expected 'sgd', 'adam' or 'rmsprop'"
        .format(name))
82
83
84
def save_checkpoint(states, predictions, is_best,
                    output_dir, filename='checkpoint.pth'):
    """Persist a training checkpoint and the current predictions.

    Writes four things into ``output_dir``:
    - ``states`` to ``filename``;
    - ``predictions`` (moved to CPU and converted to a NumPy array) to
      ``current_pred.pth``;
    - a ``latest.pth`` symlink pointing at the new checkpoint;
    - when ``is_best`` is true and ``states`` has a ``'state_dict'`` entry,
      that entry to ``model_best.pth``.

    Args:
        states: dict of objects to checkpoint, passed as-is to ``torch.save``.
        predictions: tensor of predictions for the current epoch.
        is_best: whether this checkpoint is the best one so far.
        output_dir: directory the files are written into. It is not created
            here.
        filename: checkpoint file name inside ``output_dir``.
    """
    preds = predictions.cpu().data.numpy()
    torch.save(states, os.path.join(output_dir, filename))
    torch.save(preds, os.path.join(output_dir, 'current_pred.pth'))

    latest_path = os.path.join(output_dir, 'latest.pth')
    if os.path.islink(latest_path):
        os.remove(latest_path)
    # A relative symlink target is resolved against the link's own directory,
    # not the current working directory. The checkpoint sits next to the link,
    # so target ``filename`` directly. Targeting
    # os.path.join(output_dir, filename) created a dangling link whenever
    # output_dir was a relative path.
    os.symlink(filename, latest_path)

    if is_best and 'state_dict' in states:
        best_model = states['state_dict']
        # Unwrap a DataParallel-style wrapper so the bare model is saved.
        # Objects without a ``.module`` attribute are saved unchanged instead
        # of raising AttributeError.
        best_model = getattr(best_model, 'module', best_model)
        torch.save(best_model, os.path.join(output_dir, 'model_best.pth'))
97
98
99