Path: blob/master/FaceMaskOverlay/lib/utils/utils.py
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao ([email protected])
# Modified by Ke Sun ([email protected]), Tianheng Cheng ([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import time
from pathlib import Path

import torch
import torch.optim as optim


def create_logger(cfg, cfg_name, phase='train'):
    """Create the experiment output directory, a file + console logger, and a
    TensorBoard log directory. Returns (logger, output_dir, tensorboard_log_dir)
    with both directories as strings."""
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir()

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
        (cfg_name + '_' + time_str)
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)


def get_optimizer(cfg, model):
    """Build the optimizer selected by cfg.TRAIN.OPTIMIZER ('sgd', 'adam' or
    'rmsprop') over the model's trainable parameters."""
    optimizer = None
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            nesterov=cfg.TRAIN.NESTEROV
        )
    elif cfg.TRAIN.OPTIMIZER == 'adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR
        )
    elif cfg.TRAIN.OPTIMIZER == 'rmsprop':
        optimizer = optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            alpha=cfg.TRAIN.RMSPROP_ALPHA,
            centered=cfg.TRAIN.RMSPROP_CENTERED
        )

    return optimizer


def save_checkpoint(states, predictions, is_best,
                    output_dir, filename='checkpoint.pth'):
    """Save the training states and the current predictions, point the
    'latest.pth' symlink at the new checkpoint, and store the best model
    separately when is_best is True."""
    preds = predictions.cpu().data.numpy()
    torch.save(states, os.path.join(output_dir, filename))
    torch.save(preds, os.path.join(output_dir, 'current_pred.pth'))

    latest_path = os.path.join(output_dir, 'latest.pth')
    if os.path.islink(latest_path):
        os.remove(latest_path)
    os.symlink(os.path.join(output_dir, filename), latest_path)

    if is_best and 'state_dict' in states.keys():
        torch.save(states['state_dict'].module,
                   os.path.join(output_dir, 'model_best.pth'))
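# ------------------------------------------------------------------------------
# Usage sketch: a minimal example of how these helpers might be wired together
# in a training script. It assumes `cfg` is a yacs-style config exposing the
# fields read above; `build_model`, `epoch`, and `predictions` are hypothetical
# placeholders, not defined in this file.
#
#   logger, final_output_dir, tb_log_dir = create_logger(cfg, 'experiment.yaml')
#   model = torch.nn.DataParallel(build_model(cfg)).cuda()
#   optimizer = get_optimizer(cfg, model)
#   # ... train one epoch, collecting `predictions` as a torch.Tensor ...
#   save_checkpoint(
#       {'state_dict': model, 'epoch': epoch + 1,
#        'optimizer': optimizer.state_dict()},
#       predictions, is_best=True, output_dir=final_output_dir)
#
# Note: the is_best branch reads states['state_dict'].module, so it expects the
# 'state_dict' entry to be a DataParallel-wrapped model (anything with .module).
# ------------------------------------------------------------------------------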