Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FaceMaskOverlay/lib/core/function.py
3443 views
1
# ------------------------------------------------------------------------------
2
# Copyright (c) Microsoft
3
# Licensed under the MIT License.
4
# Created by Tianheng Cheng([email protected])
5
# ------------------------------------------------------------------------------
6
7
from __future__ import absolute_import
8
from __future__ import division
9
from __future__ import print_function
10
11
import time
12
import logging
13
14
import torch
15
import numpy as np
16
17
from .evaluation import decode_preds, compute_nme
18
19
logger = logging.getLogger(__name__)
20
21
22
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: most recent value passed to ``update``.
        avg: running weighted mean of all values seen since the last reset.
        sum: weighted sum of values (``val * n`` accumulated).
        count: total weight (number of samples) accumulated.
    """

    def __init__(self):
        # reset() performs the full initialization; the original duplicated
        # all four assignments here before calling reset(), which was redundant.
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the mean.

        Args:
            val: value for the latest batch (e.g. mean loss per sample).
            n: number of samples the value covers (defaults to 1).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
42
43
44
def train(config, train_loader, model, critertion, optimizer,
          epoch, writer_dict):
    """Run one training epoch.

    For each batch: forward pass, loss, NME accumulation on decoded
    landmark predictions, then backward pass and optimizer step.
    Progress is logged every ``config.PRINT_FREQ`` batches and the
    training loss is optionally written to TensorBoard via ``writer_dict``.

    Args:
        config: experiment config; reads PRINT_FREQ here.
        train_loader: yields (input, target, meta) batches; meta must
            carry 'center' and 'scale' entries for decode_preds.
        model: network being trained (switched to train mode here).
        critertion: loss function applied to (output, target).
            NOTE(review): parameter name is a typo for "criterion" but is
            part of the public signature, so it is kept for compatibility.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, used only for logging.
        writer_dict: optional dict with 'writer' (TensorBoard writer) and
            'train_global_steps' (mutated in place, incremented per log).
    """

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()
    # accumulate NME across the whole epoch for the end-of-epoch summary
    nme_count = 0
    nme_batch_sum = 0

    end = time.time()

    for i, (inp, target, meta) in enumerate(train_loader):
        # measure data time
        data_time.update(time.time()-end)

        # compute the output
        output = model(inp)
        # NOTE(review): assumes a CUDA device is available — input stays on
        # whatever device model expects while the target is moved explicitly.
        target = target.cuda(non_blocking=True)

        loss = critertion(output, target)

        # NME: decode heatmaps (copied to CPU) back to image coordinates.
        # [64, 64] is the heatmap resolution used by decode_preds.
        score_map = output.data.cpu()
        preds = decode_preds(score_map, meta['center'], meta['scale'], [64, 64])

        nme_batch = compute_nme(preds, meta)
        nme_batch_sum = nme_batch_sum + np.sum(nme_batch)
        nme_count = nme_count + preds.size(0)

        # optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss.item() is the per-sample mean, weighted by batch size
        losses.update(loss.item(), inp.size(0))

        batch_time.update(time.time()-end)
        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=inp.size(0)/batch_time.val,
                      data_time=data_time, loss=losses)
            logger.info(msg)

            if writer_dict:
                writer = writer_dict['writer']
                global_steps = writer_dict['train_global_steps']
                writer.add_scalar('train_loss', losses.val, global_steps)
                writer_dict['train_global_steps'] = global_steps + 1

        end = time.time()
    # epoch-level normalized mean error over all samples seen
    nme = nme_batch_sum / nme_count
    msg = 'Train Epoch {} time:{:.4f} loss:{:.4f} nme:{:.4f}'\
        .format(epoch, batch_time.avg, losses.avg, nme)
    logger.info(msg)
105
106
107
def validate(config, val_loader, model, criterion, epoch, writer_dict):
    """Evaluate the model for one epoch on the validation set.

    Computes the validation loss, the normalized mean error (NME) of the
    decoded landmark predictions, and the failure rates at the 0.08 and
    0.10 NME thresholds.  Scalars are optionally written to TensorBoard
    via ``writer_dict``.

    Args:
        config: experiment config; reads MODEL.NUM_JOINTS here.
        val_loader: yields (input, target, meta) batches; meta must carry
            'center', 'scale' and 'index' entries.
        model: network under evaluation (switched to eval mode here).
        criterion: loss function applied to (output, target).
        epoch: epoch number, used only for logging.
        writer_dict: optional dict with 'writer' and 'valid_global_steps'
            (the step counter is incremented in place).

    Returns:
        Tuple ``(nme, predictions)``: the overall NME and a
        (num_images, num_joints, 2) tensor of decoded landmark coordinates
        indexed by each sample's ``meta['index']``.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()

    loss_meter = AverageMeter()

    num_landmarks = config.MODEL.NUM_JOINTS
    all_preds = torch.zeros((len(val_loader.dataset), num_landmarks, 2))

    model.eval()

    sample_total = 0
    nme_accum = 0
    failures_008 = 0
    failures_010 = 0
    tic = time.time()

    with torch.no_grad():
        for images, heatmaps_gt, meta in val_loader:
            timer_data.update(time.time() - tic)
            net_out = model(images)
            heatmaps_gt = heatmaps_gt.cuda(non_blocking=True)

            heatmaps = net_out.data.cpu()
            # loss is computed on the raw network output
            batch_loss = criterion(net_out, heatmaps_gt)

            decoded = decode_preds(heatmaps, meta['center'], meta['scale'], [64, 64])
            # per-image normalized mean error for this batch
            batch_nme = compute_nme(decoded, meta)
            # count images whose NME exceeds each failure threshold
            failures_008 += (batch_nme > 0.08).sum()
            failures_010 += (batch_nme > 0.10).sum()

            nme_accum += np.sum(batch_nme)
            sample_total += decoded.size(0)
            # scatter this batch's predictions into the dataset-wide tensor
            for k in range(heatmaps.size(0)):
                all_preds[meta['index'][k], :, :] = decoded[k, :, :]

            loss_meter.update(batch_loss.item(), images.size(0))

            # measure elapsed time
            timer_batch.update(time.time() - tic)
            tic = time.time()

    nme = nme_accum / sample_total
    failure_008_rate = failures_008 / sample_total
    failure_010_rate = failures_010 / sample_total

    msg = 'Test Epoch {} time:{:.4f} loss:{:.4f} nme:{:.4f} [008]:{:.4f} ' \
          '[010]:{:.4f}'.format(epoch, timer_batch.avg, loss_meter.avg, nme,
                                failure_008_rate, failure_010_rate)
    logger.info(msg)

    if writer_dict:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_loss', loss_meter.avg, global_steps)
        writer.add_scalar('valid_nme', nme, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return nme, all_preds
171
172
173
def inference(config, data_loader, model):
    """Run the model over a data loader and collect decoded predictions.

    Like ``validate`` but without a criterion or TensorBoard logging:
    computes the NME of the decoded landmarks and the failure rates at
    the 0.08 and 0.10 thresholds.

    Args:
        config: experiment config; reads MODEL.NUM_JOINTS here.
        data_loader: yields (input, target, meta) batches; the target is
            ignored, and meta must carry 'center', 'scale' and 'index'.
        model: network to run (switched to eval mode here).

    Returns:
        Tuple ``(nme, predictions)``: the overall NME and a
        (num_images, num_joints, 2) tensor of decoded landmark coordinates
        indexed by each sample's ``meta['index']``.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()
    # NOTE(review): no criterion is available here, so this meter is never
    # updated and the logged loss is always 0 — preserved from the original.
    loss_meter = AverageMeter()

    num_landmarks = config.MODEL.NUM_JOINTS
    all_preds = torch.zeros((len(data_loader.dataset), num_landmarks, 2))

    model.eval()

    sample_total = 0
    nme_accum = 0
    failures_008 = 0
    failures_010 = 0
    tic = time.time()

    with torch.no_grad():
        for images, _target, meta in data_loader:
            timer_data.update(time.time() - tic)
            net_out = model(images)
            heatmaps = net_out.data.cpu()
            decoded = decode_preds(heatmaps, meta['center'], meta['scale'], [64, 64])

            # per-image normalized mean error for this batch
            batch_nme = compute_nme(decoded, meta)

            # count images whose NME exceeds each failure threshold
            failures_008 += (batch_nme > 0.08).sum()
            failures_010 += (batch_nme > 0.10).sum()

            nme_accum += np.sum(batch_nme)
            sample_total += decoded.size(0)
            # scatter this batch's predictions into the dataset-wide tensor
            for k in range(heatmaps.size(0)):
                all_preds[meta['index'][k], :, :] = decoded[k, :, :]

            # measure elapsed time
            timer_batch.update(time.time() - tic)
            tic = time.time()

    nme = nme_accum / sample_total
    failure_008_rate = failures_008 / sample_total
    failure_010_rate = failures_010 / sample_total

    msg = 'Test Results time:{:.4f} loss:{:.4f} nme:{:.4f} [008]:{:.4f} ' \
          '[010]:{:.4f}'.format(timer_batch.avg, loss_meter.avg, nme,
                                failure_008_rate, failure_010_rate)
    logger.info(msg)

    return nme, all_preds
223
224
225
226
227