Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FaceMaskOverlay/tools/test.py
3142 views
1
# ------------------------------------------------------------------------------
2
# Copyright (c) Microsoft
3
# Licensed under the MIT License.
4
# Created by Tianheng Cheng([email protected])
5
# ------------------------------------------------------------------------------
6
7
import os
8
import pprint
9
import argparse
10
11
import torch
12
import torch.nn as nn
13
import torch.backends.cudnn as cudnn
14
from torch.utils.data import DataLoader
15
import sys
16
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
17
import lib.models as models
18
from lib.config import config, update_config
19
from lib.utils import utils
20
from lib.datasets import get_dataset
21
from lib.core import function
22
23
24
def parse_args(argv=None):
    """Parse command-line options for face-alignment evaluation.

    After parsing, the chosen experiment configuration is merged into the
    module-level ``config`` via ``update_config`` (its return value is
    ignored, so it presumably updates ``config`` in place).

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which keeps the normal
            command-line behaviour; passing a list lets other code (or tests)
            drive this function without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``cfg`` (experiment configuration filename)
        and ``model_file`` (path to the checkpoint to evaluate).
    """
    # This is the evaluation entry point, so describe it as such in --help.
    parser = argparse.ArgumentParser(description='Test Face Alignment')

    parser.add_argument('--cfg', help='experiment configuration filename',
                        required=True, type=str)
    parser.add_argument('--model-file', help='model parameters',
                        required=True, type=str)

    args = parser.parse_args(argv)
    update_config(config, args)
    return args
35
36
37
def main():
    """Evaluate a trained face-alignment model on the test split.

    Steps: parse CLI options (which also updates the global ``config``),
    create the logger / output directory, configure cuDNN, build the network
    wrapped in ``nn.DataParallel`` on the configured GPUs, load weights from
    ``--model-file``, run ``function.inference`` over the test dataset and
    save the returned predictions to ``<final_output_dir>/predictions.pth``.
    """
    args = parse_args()

    # The third value (TensorBoard log dir) is not needed for evaluation.
    logger, final_output_dir, _ = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cuDNN switches. The attribute is spelled ``deterministic``; the old
    # misspelling ``determinstic`` just created an unused module attribute
    # and silently ignored config.CUDNN.DETERMINISTIC.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # Weights are loaded from --model-file below, so turn off INIT_WEIGHTS
    # (presumably the model factory's pretrained-weight init) before building.
    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Load weights. map_location='cpu' lets a checkpoint saved on a GPU
    # index that does not exist on this machine still deserialize;
    # load_state_dict then copies the tensors into the CUDA parameters.
    # NOTE(review): torch.load unpickles the file -- only pass trusted
    # checkpoints.
    state_dict = torch.load(args.model_file, map_location='cpu')
    if 'state_dict' in state_dict:
        # Checkpoint dict wrapping the weights under 'state_dict'. They are
        # loaded into the DataParallel wrapper, so their keys must carry the
        # wrapper's 'module.' prefix.
        state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
    else:
        # Bare state dict: loaded into the wrapped network itself, so its
        # keys must match the inner model (no 'module.' prefix).
        model.module.load_state_dict(state_dict)

    dataset_type = get_dataset(config)

    # One per-GPU batch for every device DataParallel splits across.
    test_loader = DataLoader(
        dataset=dataset_type(config, is_train=False),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    # Only the predictions are persisted; the first return value (named
    # ``nme`` in the original code) is not used here.
    _nme, predictions = function.inference(config, test_loader, model)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
81
82
83
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
85
86
87